1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/distributed_runtime/master_session.h"
17
#include <algorithm>
#include <atomic>
#include <chrono>
#include <functional>
20#include <memory>
21#include <string>
22#include <unordered_map>
23#include <unordered_set>
24#include <utility>
25#include <vector>

#include "absl/strings/str_join.h"

27#include "tensorflow/core/common_runtime/process_util.h"
28#include "tensorflow/core/common_runtime/profile_handler.h"
29#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
30#include "tensorflow/core/debug/debug_graph_utils.h"
31#include "tensorflow/core/distributed_runtime/request_id.h"
32#include "tensorflow/core/distributed_runtime/scheduler.h"
33#include "tensorflow/core/distributed_runtime/worker_cache.h"
34#include "tensorflow/core/distributed_runtime/worker_interface.h"
35#include "tensorflow/core/framework/allocation_description.pb.h"
36#include "tensorflow/core/framework/collective.h"
37#include "tensorflow/core/framework/cost_graph.pb.h"
38#include "tensorflow/core/framework/graph_def_util.h"
39#include "tensorflow/core/framework/node_def.pb.h"
40#include "tensorflow/core/framework/node_def_util.h"
41#include "tensorflow/core/framework/tensor.h"
42#include "tensorflow/core/framework/tensor.pb.h"
43#include "tensorflow/core/framework/tensor_description.pb.h"
44#include "tensorflow/core/graph/graph_partition.h"
45#include "tensorflow/core/graph/tensor_id.h"
46#include "tensorflow/core/lib/core/notification.h"
47#include "tensorflow/core/lib/core/refcount.h"
48#include "tensorflow/core/lib/core/status.h"
49#include "tensorflow/core/lib/gtl/cleanup.h"
50#include "tensorflow/core/lib/gtl/inlined_vector.h"
51#include "tensorflow/core/lib/gtl/map_util.h"
52#include "tensorflow/core/lib/random/random.h"
53#include "tensorflow/core/lib/strings/numbers.h"
54#include "tensorflow/core/lib/strings/str_util.h"
55#include "tensorflow/core/lib/strings/strcat.h"
56#include "tensorflow/core/lib/strings/stringprintf.h"
57#include "tensorflow/core/platform/blocking_counter.h"
58#include "tensorflow/core/platform/env.h"
59#include "tensorflow/core/platform/logging.h"
60#include "tensorflow/core/platform/macros.h"
61#include "tensorflow/core/platform/mutex.h"
62#include "tensorflow/core/platform/tracing.h"
63#include "tensorflow/core/protobuf/config.pb.h"
64#include "tensorflow/core/protobuf/coordination_config.pb.h"
65#include "tensorflow/core/public/session_options.h"
66#include "tensorflow/core/util/device_name_utils.h"
67
68namespace tensorflow {
69
70// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear the cache mapping Run requests to
72// compiled graphs while the compiled graph is still being used.
73//
74// TODO(zhifengc): Cleanup this class. It's becoming messy.
75class MasterSession::ReffedClientGraph : public core::RefCounted {
76 public:
77 ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
78 std::unique_ptr<ClientGraph> client_graph,
79 const SessionOptions& session_opts,
80 const StatsPublisherFactory& stats_publisher_factory,
81 bool is_partial, WorkerCacheInterface* worker_cache,
82 bool should_deregister)
83 : session_handle_(handle),
84 bg_opts_(bopts),
85 client_graph_before_register_(std::move(client_graph)),
86 session_opts_(session_opts),
87 is_partial_(is_partial),
88 callable_opts_(bopts.callable_options),
89 worker_cache_(worker_cache),
90 should_deregister_(should_deregister),
91 collective_graph_key_(
92 client_graph_before_register_->collective_graph_key) {
    VLOG(1) << "Created ReffedClientGraph with "
            << client_graph_before_register_->graph.num_node_ids()
            << " node ids";
95
96 stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
97
98 // Initialize a name to node map for processing device stats.
99 for (Node* n : client_graph_before_register_->graph.nodes()) {
100 name_to_node_details_.emplace(
101 n->name(),
102 NodeDetails(n->type_string(),
                      strings::StrCat(
                          "(", absl::StrJoin(n->requested_inputs(), ", "),
                          ")")));
105 }
106 }
107
108 ~ReffedClientGraph() override {
109 if (should_deregister_) {
110 DeregisterPartitions();
111 } else {
112 for (Part& part : partitions_) {
113 worker_cache_->ReleaseWorker(part.name, part.worker);
114 }
115 }
116 }
117
118 const CallableOptions& callable_options() { return callable_opts_; }
119
120 const BuildGraphOptions& build_graph_options() { return bg_opts_; }
121
122 int64_t collective_graph_key() { return collective_graph_key_; }
123
124 std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
125 int64_t execution_count,
126 const RunOptions& ropts) {
127 return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
128 }
129
130 int64_t get_and_increment_execution_count() {
131 return execution_count_.fetch_add(1);
132 }
133
134 // Turn RPC logging on or off, both at the WorkerCache used by this
135 // master process, and at each remote worker in use for the current
136 // partitions.
137 void SetRPCLogging(bool active) {
138 worker_cache_->SetLogging(active);
139 // Logging is a best-effort activity, so we make async calls to turn
140 // it on/off and don't make use of the responses.
141 for (auto& p : partitions_) {
142 LoggingRequest* req = new LoggingRequest;
143 if (active) {
144 req->set_enable_rpc_logging(true);
145 } else {
146 req->set_disable_rpc_logging(true);
147 }
148 LoggingResponse* resp = new LoggingResponse;
149 Ref();
150 p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
151 delete req;
152 delete resp;
153 // ReffedClientGraph owns p.worker so we need to hold a ref to
154 // ensure that the method doesn't attempt to access p.worker after
        // ReffedClientGraph has deleted it.
156 // TODO(suharshs): Simplify this ownership model.
157 Unref();
158 });
159 }
160 }
161
162 // Retrieve all RPC logs data accumulated for the current step, both
163 // from the local WorkerCache in use by this master process and from
164 // all the remote workers executing the remote partitions.
165 void RetrieveLogs(int64_t step_id, StepStats* ss) {
166 // Get the local data first, because it sets *ss without merging.
167 worker_cache_->RetrieveLogs(step_id, ss);
168
169 // Then merge in data from all the remote workers.
170 LoggingRequest req;
171 req.add_fetch_step_id(step_id);
172 int waiting_for = partitions_.size();
173 if (waiting_for > 0) {
174 mutex scoped_mu;
175 BlockingCounter all_done(waiting_for);
176 for (auto& p : partitions_) {
177 LoggingResponse* resp = new LoggingResponse;
178 p.worker->LoggingAsync(
179 &req, resp,
180 [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
181 {
182 mutex_lock l(scoped_mu);
183 if (s.ok()) {
184 for (auto& lss : resp->step()) {
185 if (step_id != lss.step_id()) {
186 LOG(ERROR) << "Wrong step_id in LoggingResponse";
187 continue;
188 }
189 ss->MergeFrom(lss.step_stats());
190 }
191 }
192 delete resp;
193 }
194 // Must not decrement all_done until out of critical section where
195 // *ss is updated.
196 all_done.DecrementCount();
197 });
198 }
199 all_done.Wait();
200 }
201 }
202
203 // Local execution methods.
204
205 // Partitions the graph into subgraphs and registers them on
206 // workers.
207 Status RegisterPartitions(PartitionOptions popts);
208
209 // Runs one step of all partitions.
210 Status RunPartitions(const MasterEnv* env, int64_t step_id,
211 int64_t execution_count, PerStepState* pss,
212 CallOptions* opts, const RunStepRequestWrapper& req,
213 MutableRunStepResponseWrapper* resp,
214 CancellationManager* cm, const bool is_last_partial_run);
215 Status RunPartitions(const MasterEnv* env, int64_t step_id,
216 int64_t execution_count, PerStepState* pss,
217 CallOptions* call_opts, const RunCallableRequest& req,
218 RunCallableResponse* resp, CancellationManager* cm);
219
220 // Calls workers to cleanup states for the step "step_id". Calls
221 // `done` when all cleanup RPCs have completed.
222 void CleanupPartitionsAsync(int64_t step_id, StatusCallback done);
223
224 // Post-processing of any runtime statistics gathered during execution.
225 void ProcessStats(int64_t step_id, PerStepState* pss, ProfileHandler* ph,
226 const RunOptions& options, RunMetadata* resp);
227 void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
228 bool is_rpc);
229 // Checks that the requested fetches can be computed from the provided feeds.
230 Status CheckFetches(const RunStepRequestWrapper& req,
231 const RunState* run_state,
232 GraphExecutionState* execution_state);
233
234 private:
235 const string session_handle_;
236 const BuildGraphOptions bg_opts_;
237
238 // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
239 std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
240 const SessionOptions session_opts_;
241 const bool is_partial_;
242 const CallableOptions callable_opts_;
243 WorkerCacheInterface* const worker_cache_; // Not owned.
244
245 struct NodeDetails {
246 explicit NodeDetails(string type_string, string detail_text)
247 : type_string(std::move(type_string)),
248 detail_text(std::move(detail_text)) {}
249 const string type_string;
250 const string detail_text;
251 };
252 std::unordered_map<string, NodeDetails> name_to_node_details_;
253
254 const bool should_deregister_;
255 const int64_t collective_graph_key_;
256 std::atomic<int64_t> execution_count_ = {0};
257
258 // Graph partitioned into per-location subgraphs.
259 struct Part {
260 // Worker name.
261 string name;
262
263 // Maps feed names to rendezvous keys. Empty most of the time.
264 std::unordered_map<string, string> feed_key;
265
266 // Maps rendezvous keys to fetch names. Empty most of the time.
267 std::unordered_map<string, string> key_fetch;
268
269 // The interface to the worker. Owned.
270 WorkerInterface* worker = nullptr;
271
272 // After registration with the worker, graph_handle identifies
273 // this partition on the worker.
274 string graph_handle;
275
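    // The "3" below is only an initial bucket-count hint for these maps; most
    // partitions have just a handful of client-terminated feeds/fetches.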
276 Part() : feed_key(3), key_fetch(3) {}
277 };
278
279 // partitions_ is immutable after RegisterPartitions() call
280 // finishes. RunPartitions() can access partitions_ safely without
281 // acquiring locks.
282 std::vector<Part> partitions_;
283
284 mutable mutex mu_;
285
286 // Partition initialization and registration only needs to happen
287 // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
288 // indicates the initialization is ongoing.
289 Notification init_done_;
290
291 // init_result_ remembers the initialization error if any.
292 Status init_result_ TF_GUARDED_BY(mu_);
293
294 std::unique_ptr<StatsPublisherInterface> stats_publisher_;
295
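  // Builds a human-readable line describing one node's execution for the
  // profiler: an optional "[x.xMB]" prefix with the total requested output
  // bytes, followed by the node name, its op type and its requested inputs.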
296 string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
297 int64_t tot = 0;
298 for (auto& no : stats.output()) {
299 tot += no.tensor_description().allocation_description().requested_bytes();
300 }
301 string bytes;
302 if (tot >= 0.1 * 1048576.0) {
303 bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
304 }
305 return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
306 details.detail_text);
307 }
308
309 // Send/Recv nodes that are the result of client-added
310 // feeds and fetches must be tracked so that the tensors
311 // can be added to the local rendezvous.
312 static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
313 const PartitionOptions& popts);
314
315 // The actual graph partitioning and registration implementation.
316 Status DoBuildPartitions(
317 PartitionOptions popts, ClientGraph* client_graph,
318 std::unordered_map<string, GraphDef>* out_partitions);
319 Status DoRegisterPartitions(
320 const PartitionOptions& popts,
321 std::unordered_map<string, GraphDef> graph_partitions);
322
323 // Prepares a number of calls to workers. One call per partition.
324 // This is a generic method that handles Run, PartialRun, and RunCallable.
325 template <class FetchListType, class ClientRequestType,
326 class ClientResponseType>
327 Status RunPartitionsHelper(
328 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
329 const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
330 int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
331 const ClientRequestType& req, ClientResponseType* resp,
332 CancellationManager* cm, bool is_last_partial_run);
333
334 // Deregisters the partitions on the workers. Called in the
335 // destructor and does not wait for the rpc completion.
336 void DeregisterPartitions();
337
338 TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
339};
340
341Status MasterSession::ReffedClientGraph::RegisterPartitions(
342 PartitionOptions popts) {
343 { // Ensure register once.
344 mu_.lock();
345 if (client_graph_before_register_) {
346 // The `ClientGraph` is no longer needed after partitions are registered.
347 // Since it can account for a large amount of memory, we consume it here,
348 // and it will be freed after concluding with registration.
349
350 std::unique_ptr<ClientGraph> client_graph;
351 std::swap(client_graph_before_register_, client_graph);
352 mu_.unlock();
353 std::unordered_map<string, GraphDef> graph_defs;
354 popts.flib_def = client_graph->flib_def.get();
355 Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
356 if (s.ok()) {
357 // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
358 // valid after the call to DoRegisterPartitions begins, so
359 // `stats_publisher_` must make a copy if it wants to retain the
360 // GraphDef objects.
361 std::vector<const GraphDef*> graph_defs_for_publishing;
362 graph_defs_for_publishing.reserve(partitions_.size());
363 for (const auto& name_def : graph_defs) {
364 graph_defs_for_publishing.push_back(&name_def.second);
365 }
366 stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
367 s = DoRegisterPartitions(popts, std::move(graph_defs));
368 }
369 mu_.lock();
370 init_result_ = s;
371 init_done_.Notify();
372 } else {
373 mu_.unlock();
374 init_done_.WaitForNotification();
375 mu_.lock();
376 }
377 const Status result = init_result_;
378 mu_.unlock();
379 return result;
380 }
381}
382
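// Returns the task name (e.g. "/job:worker/replica:0/task:1") extracted from
// the node's assigned device name. The graph is partitioned per task, so this
// serves as the partitioning key.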
383static string SplitByWorker(const Node* node) {
384 string task;
385 string device;
386 CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
387 &device))
388 << "node: " << node->name() << " dev: " << node->assigned_device_name();
389 return task;
390}
391
392void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
393 Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
394 for (int i = 0; i < graph_def.node_size(); ++i) {
395 const NodeDef& ndef = graph_def.node(i);
396 const bool is_recv = ndef.op() == "_Recv";
397 const bool is_send = ndef.op() == "_Send";
398
399 if (is_recv || is_send) {
400 // Only send/recv nodes that were added as feeds and fetches
401 // (client-terminated) should be tracked. Other send/recv nodes
402 // are for transferring data between partitions / memory spaces.
403 bool client_terminated;
404 TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
405 if (client_terminated) {
406 string name;
407 TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
408 string send_device;
409 TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
410 string recv_device;
411 TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
412 uint64 send_device_incarnation;
413 TF_CHECK_OK(
414 GetNodeAttr(ndef, "send_device_incarnation",
415 reinterpret_cast<int64_t*>(&send_device_incarnation)));
416 const string& key =
417 Rendezvous::CreateKey(send_device, send_device_incarnation,
418 recv_device, name, FrameAndIter(0, 0));
419
420 if (is_recv) {
421 part->feed_key.insert({name, key});
422 } else {
423 part->key_fetch.insert({key, name});
424 }
425 }
426 }
427 }
428}
429
430Status MasterSession::ReffedClientGraph::DoBuildPartitions(
431 PartitionOptions popts, ClientGraph* client_graph,
432 std::unordered_map<string, GraphDef>* out_partitions) {
433 if (popts.need_to_record_start_times) {
434 CostModel cost_model(true);
435 cost_model.InitFromGraph(client_graph->graph);
436 // TODO(yuanbyu): Use the real cost model.
437 // execution_state_->MergeFromGlobal(&cost_model);
438 SlackAnalysis sa(&client_graph->graph, &cost_model);
439 sa.ComputeAsap(&popts.start_times);
440 }
441
442 // Partition the graph.
443 return Partition(popts, &client_graph->graph, out_partitions);
444}
445
446Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
447 const PartitionOptions& popts,
448 std::unordered_map<string, GraphDef> graph_partitions) {
449 partitions_.reserve(graph_partitions.size());
450 Status s;
451 for (auto& name_def : graph_partitions) {
452 partitions_.emplace_back();
453 Part* part = &partitions_.back();
454 part->name = name_def.first;
455 TrackFeedsAndFetches(part, name_def.second, popts);
456 part->worker = worker_cache_->GetOrCreateWorker(part->name);
457 if (part->worker == nullptr) {
458 s = errors::NotFound("worker ", part->name);
459 break;
460 }
461 }
462 if (!s.ok()) {
463 for (Part& part : partitions_) {
464 worker_cache_->ReleaseWorker(part.name, part.worker);
465 part.worker = nullptr;
466 }
467 return s;
468 }
469 struct Call {
470 RegisterGraphRequest req;
471 RegisterGraphResponse resp;
472 Status status;
473 };
474 const int num = partitions_.size();
475 gtl::InlinedVector<Call, 4> calls(num);
476 BlockingCounter done(num);
477 for (int i = 0; i < num; ++i) {
478 const Part& part = partitions_[i];
479 Call* c = &calls[i];
480 c->req.set_session_handle(session_handle_);
481 c->req.set_create_worker_session_called(!should_deregister_);
482 c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
483 StripDefaultAttributes(*OpRegistry::Global(),
484 c->req.mutable_graph_def()->mutable_node());
485 *c->req.mutable_config_proto() = session_opts_.config;
486 *c->req.mutable_graph_options() = session_opts_.config.graph_options();
487 *c->req.mutable_debug_options() =
488 callable_opts_.run_options().debug_options();
489 c->req.set_collective_graph_key(collective_graph_key_);
490 VLOG(2) << "Register " << c->req.graph_def().DebugString();
491 auto cb = [c, &done](const Status& s) {
492 c->status = s;
493 done.DecrementCount();
494 };
495 part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
496 }
497 done.Wait();
498 for (int i = 0; i < num; ++i) {
499 Call* c = &calls[i];
500 s.Update(c->status);
501 partitions_[i].graph_handle = c->resp.graph_handle();
502 }
503 return s;
504}
505
506namespace {
507// Helper class to manage "num" parallel RunGraph calls.
508class RunManyGraphs {
509 public:
510 explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
511
512 ~RunManyGraphs() {}
513
514 // Returns the index-th call.
515 struct Call {
516 CallOptions opts;
517 const string* worker_name;
518 std::atomic<bool> done{false};
519 std::unique_ptr<MutableRunGraphRequestWrapper> req;
520 std::unique_ptr<MutableRunGraphResponseWrapper> resp;
521 };
522 Call* get(int index) { return &calls_[index]; }
523
524 // When the index-th call is done, updates the overall status.
525 void WhenDone(int index, const Status& s) {
526 TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
527 Call* call = get(index);
528 call->done = true;
529 auto resp = call->resp.get();
530 if (resp->status_code() != error::Code::OK) {
531 // resp->status_code will only be non-OK if s.ok().
532 mutex_lock l(mu_);
533 Status resp_status = call->resp->status();
534 ReportBadStatus(errors::CreateWithUpdatedMessage(
535 resp_status, strings::StrCat("From ", *call->worker_name, ":\n",
536 resp_status.error_message())));
537 } else if (!s.ok()) {
538 mutex_lock l(mu_);
539 ReportBadStatus(errors::CreateWithUpdatedMessage(
540 s, strings::StrCat("From ", *call->worker_name, ":\n",
541 s.error_message())));
542 }
543 pending_.DecrementCount();
544 }
545
546 void StartCancel() {
547 mutex_lock l(mu_);
548 ReportBadStatus(errors::Cancelled("RunManyGraphs"));
549 }
550
551 void Wait() {
    // Check the error status every 60 seconds in order to print a log message
553 // in the event of a hang.
554 const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
555 while (true) {
556 if (pending_.WaitFor(kCheckErrorPeriod)) {
557 return;
558 }
559 if (!status().ok()) {
560 break;
561 }
562 }
563
564 // The step has failed. Wait for another 60 seconds before diagnosing a
565 // hang.
566 DCHECK(!status().ok());
567 if (pending_.WaitFor(kCheckErrorPeriod)) {
568 return;
569 }
570 LOG(ERROR)
571 << "RunStep still blocked after 60 seconds. Failed with error status: "
572 << status();
573 for (const Call& call : calls_) {
574 if (!call.done) {
575 LOG(ERROR) << "- No response from RunGraph call to worker: "
576 << *call.worker_name;
577 }
578 }
579 pending_.Wait();
580 }
581
582 Status status() const {
583 mutex_lock l(mu_);
584 // Concat status objects in this StatusGroup to get the aggregated status,
    // as each status in status_group_ is already a summarized status.
586 return status_group_.as_concatenated_status();
587 }
588
589 private:
590 gtl::InlinedVector<Call, 4> calls_;
591
592 BlockingCounter pending_;
593 mutable mutex mu_;
594 StatusGroup status_group_ TF_GUARDED_BY(mu_);
595 bool cancel_issued_ TF_GUARDED_BY(mu_) = false;
596
597 void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
598 VLOG(1) << "Master received error status " << s;
599 if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
      // Only start cancelling other workers upon receiving a non-derived
      // error.
602 cancel_issued_ = true;
603
604 VLOG(1) << "Master received error report. Cancelling remaining workers.";
605 for (Call& call : calls_) {
606 call.opts.StartCancel();
607 }
608 }
609
610 status_group_.Update(s);
611 }
612
613 TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
614};
615
616Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
617 MutableRunGraphRequestWrapper* worker_req,
618 size_t index, const string& send_key) {
619 return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
620}
621
622Status AddSendFromClientRequest(const RunCallableRequest& client_req,
623 MutableRunGraphRequestWrapper* worker_req,
624 size_t index, const string& send_key) {
625 return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
626}
627
628// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
629// in-process messages.
630struct RunCallableResponseWrapper {
631 RunCallableResponse* resp; // Not owned.
632 std::unordered_map<string, TensorProto> fetch_key_to_protos;
633
634 RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
635
636 Status AddTensorFromRunGraphResponse(
637 const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
638 size_t index) {
639 return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
640 }
641};
642} // namespace
643
644template <class FetchListType, class ClientRequestType,
645 class ClientResponseType>
646Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
647 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
648 const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
649 int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
650 const ClientRequestType& req, ClientResponseType* resp,
651 CancellationManager* cm, bool is_last_partial_run) {
652 // Collect execution cost stats on a smoothly decreasing frequency.
653 ExecutorOpts exec_opts;
654 if (pss->report_tensor_allocations_upon_oom) {
655 exec_opts.set_report_tensor_allocations_upon_oom(true);
656 }
657 if (pss->collect_costs) {
658 exec_opts.set_record_costs(true);
659 }
660 if (pss->collect_timeline) {
661 exec_opts.set_record_timeline(true);
662 }
663 if (pss->collect_rpcs) {
664 SetRPCLogging(true);
665 }
666 if (pss->collect_partition_graphs) {
667 exec_opts.set_record_partition_graphs(true);
668 }
669 if (pss->collect_costs || pss->collect_timeline) {
670 pss->step_stats.resize(partitions_.size());
671 }
672
673 const int num = partitions_.size();
674 RunManyGraphs calls(num);
675
676 for (int i = 0; i < num; ++i) {
677 const Part& part = partitions_[i];
678 RunManyGraphs::Call* c = calls.get(i);
679 c->worker_name = &part.name;
680 c->req.reset(part.worker->CreateRunGraphRequest());
681 c->resp.reset(part.worker->CreateRunGraphResponse());
682 if (is_partial_) {
683 c->req->set_is_partial(is_partial_);
684 c->req->set_is_last_partial_run(is_last_partial_run);
685 }
686 c->req->set_session_handle(session_handle_);
687 c->req->set_create_worker_session_called(!should_deregister_);
688 c->req->set_graph_handle(part.graph_handle);
689 c->req->set_step_id(step_id);
690 *c->req->mutable_exec_opts() = exec_opts;
691 c->req->set_store_errors_in_response_body(true);
692 c->req->set_request_id(GetUniqueRequestId());
693 // If any feeds are provided, send the feed values together
694 // in the RunGraph request.
695 // In the partial case, we only want to include feeds provided in the req.
696 // In the non-partial case, all feeds in the request are in the part.
697 // We keep these as separate paths for now, to ensure we aren't
698 // inadvertently slowing down the normal run path.
699 if (is_partial_) {
700 for (const auto& name_index : feeds) {
701 const auto iter = part.feed_key.find(string(name_index.first));
702 if (iter == part.feed_key.end()) {
703 // The provided feed must be for a different partition.
704 continue;
705 }
706 const string& key = iter->second;
707 TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
708 name_index.second, key));
709 }
710 // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
711 // For now, we just iterate through partitions to find the matching key.
712 for (const string& req_fetch : fetches) {
713 for (const auto& key_fetch : part.key_fetch) {
714 if (key_fetch.second == req_fetch) {
715 c->req->add_recv_key(key_fetch.first);
716 break;
717 }
718 }
719 }
720 } else {
721 for (const auto& feed_key : part.feed_key) {
722 const string& feed = feed_key.first;
723 const string& key = feed_key.second;
724 auto iter = feeds.find(feed);
725 if (iter == feeds.end()) {
726 return errors::Internal("No feed index found for feed: ", feed);
727 }
728 const int64_t feed_index = iter->second;
729 TF_RETURN_IF_ERROR(
730 AddSendFromClientRequest(req, c->req.get(), feed_index, key));
731 }
732 for (const auto& key_fetch : part.key_fetch) {
733 const string& key = key_fetch.first;
734 c->req->add_recv_key(key);
735 }
736 }
737 }
738
739 // Issues RunGraph calls.
740 for (int i = 0; i < num; ++i) {
741 const Part& part = partitions_[i];
742 RunManyGraphs::Call* call = calls.get(i);
743 TRACEPRINTF("Partition %d %s", i, part.name.c_str());
744 part.worker->RunGraphAsync(
745 &call->opts, call->req.get(), call->resp.get(),
746 std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
747 }
748
749 // Waits for the RunGraph calls.
750 call_opts->SetCancelCallback([&calls]() {
751 LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
752 "worker operations.";
753 calls.StartCancel();
754 });
755 auto token = cm->get_cancellation_token();
756 const bool success =
757 cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
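  // RegisterCallback returns false if `cm` has already been cancelled, in
  // which case we cancel the outstanding worker calls right away.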
758 if (!success) {
759 calls.StartCancel();
760 }
761 calls.Wait();
762 call_opts->ClearCancelCallback();
763 if (success) {
764 cm->DeregisterCallback(token);
765 } else {
766 return errors::Cancelled("Step was cancelled");
767 }
768 TF_RETURN_IF_ERROR(calls.status());
769
770 // Collects fetches and metadata.
771 Status status;
772 for (int i = 0; i < num; ++i) {
773 const Part& part = partitions_[i];
774 MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
775 for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
776 auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
777 if (iter == part.key_fetch.end()) {
778 status.Update(errors::Internal("Unexpected fetch key: ",
779 run_graph_resp->recv_key(j)));
780 break;
781 }
782 const string& fetch = iter->second;
783 status.Update(
784 resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
785 if (!status.ok()) {
786 break;
787 }
788 }
789 if (pss->collect_timeline) {
790 pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
791 }
792 if (pss->collect_costs) {
793 CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
794 for (int j = 0; j < cost_graph->node_size(); ++j) {
795 resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
796 cost_graph->mutable_node(j));
797 }
798 }
799 if (pss->collect_partition_graphs) {
800 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
801 resp->mutable_metadata()->mutable_partition_graphs();
802 for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
803 partition_graph_defs->Add()->Swap(
804 run_graph_resp->mutable_partition_graph(i));
805 }
806 }
807 }
808 return status;
809}
810
811Status MasterSession::ReffedClientGraph::RunPartitions(
812 const MasterEnv* env, int64_t step_id, int64_t execution_count,
813 PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
814 MutableRunStepResponseWrapper* resp, CancellationManager* cm,
815 const bool is_last_partial_run) {
816 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
817 << execution_count;
818 // Maps the names of fed tensors to their index in `req`.
819 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
820 for (size_t i = 0; i < req.num_feeds(); ++i) {
821 if (!feeds.insert({req.feed_name(i), i}).second) {
822 return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
823 }
824 }
825
826 std::vector<string> fetches;
827 fetches.reserve(req.num_fetches());
828 for (size_t i = 0; i < req.num_fetches(); ++i) {
829 fetches.push_back(req.fetch_name(i));
830 }
831
832 return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
833 call_opts, req, resp, cm, is_last_partial_run);
834}
835
836Status MasterSession::ReffedClientGraph::RunPartitions(
837 const MasterEnv* env, int64_t step_id, int64_t execution_count,
838 PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
839 RunCallableResponse* resp, CancellationManager* cm) {
840 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
841 << execution_count;
842 // Maps the names of fed tensors to their index in `req`.
843 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
844 for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
845 if (!feeds.insert({callable_opts_.feed(i), i}).second) {
846 // MakeCallable will fail if there are two feeds with the same name.
847 return errors::Internal("Duplicated feeds in callable: ",
848 callable_opts_.feed(i));
849 }
850 }
851
852 // Create a wrapped response object to collect the fetched values and
853 // rearrange them for the RunCallableResponse.
854 RunCallableResponseWrapper wrapped_resp;
855 wrapped_resp.resp = resp;
856
857 TF_RETURN_IF_ERROR(RunPartitionsHelper(
858 feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
859 call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
860
861 // Collects fetches.
862 for (const string& fetch : callable_opts_.fetch()) {
863 TensorProto* fetch_proto = resp->mutable_fetch()->Add();
864 auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
865 if (iter == wrapped_resp.fetch_key_to_protos.end()) {
866 return errors::Internal("Worker did not return a value for fetch: ",
867 fetch);
868 }
869 fetch_proto->Swap(&iter->second);
870 }
871 return OkStatus();
872}
873
874namespace {
875
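// Broadcasts a CleanupGraph request for one step to a set of workers and
// invokes `done` with the aggregated status once every response has arrived.
// The helper deletes itself after running `done`.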
876class CleanupBroadcastHelper {
877 public:
878 CleanupBroadcastHelper(int64_t step_id, int num_calls, StatusCallback done)
879 : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
880 req_.set_step_id(step_id);
881 }
882
883 // Returns a non-owned pointer to a request buffer for all calls.
884 CleanupGraphRequest* request() { return &req_; }
885
886 // Returns a non-owned pointer to a response buffer for the ith call.
887 CleanupGraphResponse* response(int i) { return &resps_[i]; }
888
889 // Called when the ith response is received.
890 void call_done(int i, const Status& s) {
891 bool run_callback = false;
892 Status status_copy;
893 {
894 mutex_lock l(mu_);
895 status_.Update(s);
896 if (--num_pending_ == 0) {
897 run_callback = true;
898 status_copy = status_;
899 }
900 }
901 if (run_callback) {
902 done_(status_copy);
903 // This is the last call, so delete the helper object.
904 delete this;
905 }
906 }
907
908 private:
909 // A single request shared between all workers.
910 CleanupGraphRequest req_;
911 // One response buffer for each worker.
912 gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
913
914 mutex mu_;
915 // Number of requests remaining to be collected.
916 int num_pending_ TF_GUARDED_BY(mu_);
917 // Aggregate status of the operation.
918 Status status_ TF_GUARDED_BY(mu_);
919 // Callback to be called when all operations complete.
920 StatusCallback done_;
921
922 TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
923};
924
925} // namespace
926
927void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
928 int64_t step_id, StatusCallback done) {
929 const int num = partitions_.size();
930 // Helper object will be deleted when the final call completes.
931 CleanupBroadcastHelper* helper =
932 new CleanupBroadcastHelper(step_id, num, std::move(done));
933 for (int i = 0; i < num; ++i) {
934 const Part& part = partitions_[i];
935 part.worker->CleanupGraphAsync(
936 helper->request(), helper->response(i),
937 [helper, i](const Status& s) { helper->call_done(i, s); });
938 }
939}
940
941void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id,
942 PerStepState* pss,
943 ProfileHandler* ph,
944 const RunOptions& options,
945 RunMetadata* resp) {
946 if (!pss->collect_costs && !pss->collect_timeline) return;
947
948 // Out-of-band logging data is collected now, during post-processing.
949 if (pss->collect_timeline) {
950 SetRPCLogging(false);
951 RetrieveLogs(step_id, &pss->rpc_stats);
952 }
953 for (size_t i = 0; i < partitions_.size(); ++i) {
954 const StepStats& ss = pss->step_stats[i];
955 if (ph) {
956 for (const auto& ds : ss.dev_stats()) {
957 ProcessDeviceStats(ph, ds, false /*is_rpc*/);
958 }
959 }
960 }
961 if (ph) {
962 for (const auto& ds : pss->rpc_stats.dev_stats()) {
963 ProcessDeviceStats(ph, ds, true /*is_rpc*/);
964 }
965 ph->StepDone(pss->start_micros, pss->end_micros,
966 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
967 OkStatus());
968 }
969 // Assemble all stats for this timeline into a merged StepStats.
970 if (pss->collect_timeline) {
971 StepStats step_stats_proto;
972 step_stats_proto.Swap(&pss->rpc_stats);
973 for (size_t i = 0; i < partitions_.size(); ++i) {
974 step_stats_proto.MergeFrom(pss->step_stats[i]);
975 pss->step_stats[i].Clear();
976 }
977 pss->step_stats.clear();
978 // Copy the stats back, but only for on-demand profiling to avoid slowing
979 // down calls that trigger the automatic profiling.
980 if (options.trace_level() == RunOptions::FULL_TRACE) {
981 resp->mutable_step_stats()->Swap(&step_stats_proto);
982 } else {
983 // If FULL_TRACE, it can be fetched from Session API, no need for
984 // duplicated publishing.
985 stats_publisher_->PublishStatsProto(step_stats_proto);
986 }
987 }
988}
989
990void MasterSession::ReffedClientGraph::ProcessDeviceStats(
991 ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
992 const string& dev_name = ds.device();
993 VLOG(1) << "Device " << dev_name << " reports stats for "
994 << ds.node_stats_size() << " nodes";
995 for (const auto& ns : ds.node_stats()) {
996 if (is_rpc) {
997 // We don't have access to a good Node pointer, so we rely on
998 // sufficient data being present in the NodeExecStats.
999 ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
1000 ns.timeline_label());
1001 } else {
1002 auto iter = name_to_node_details_.find(ns.node_name());
1003 const bool found_node_in_graph = iter != name_to_node_details_.end();
1004 if (!found_node_in_graph && ns.timeline_label().empty()) {
1005 // The counter incrementing is not thread-safe. But we don't really
1006 // care.
1007 // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
1008 // more general usage.
1009 static int log_counter = 0;
1010 if (log_counter < 10) {
1011 log_counter++;
1012 LOG(WARNING) << "Failed to find node " << ns.node_name()
1013 << " for dev " << dev_name;
1014 }
1015 continue;
1016 }
1017 const string& optype =
1018 found_node_in_graph ? iter->second.type_string : ns.node_name();
1019 string details;
1020 if (!ns.timeline_label().empty()) {
1021 details = ns.timeline_label();
1022 } else if (found_node_in_graph) {
1023 details = DetailText(iter->second, ns);
1024 } else {
1025 // Leave details string empty
1026 }
1027 ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
1028 details);
1029 }
1030 }
1031}
1032
1033// TODO(suharshs): Merge with CheckFetches in DirectSession.
1034// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
1035// on once at setup time to prevent us from computing the dependencies
// every time.
1037Status MasterSession::ReffedClientGraph::CheckFetches(
1038 const RunStepRequestWrapper& req, const RunState* run_state,
1039 GraphExecutionState* execution_state) {
1040 // Build the set of pending feeds that we haven't seen.
1041 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1042 for (const auto& input : run_state->pending_inputs) {
1043 // Skip if already fed.
1044 if (input.second) continue;
1045 TensorId id(ParseTensorName(input.first));
1046 const Node* n = execution_state->get_node_by_name(string(id.first));
1047 if (n == nullptr) {
1048 return errors::NotFound("Feed ", input.first, ": not found");
1049 }
1050 pending_feeds.insert(id);
1051 }
1052 for (size_t i = 0; i < req.num_feeds(); ++i) {
1053 const TensorId id(ParseTensorName(req.feed_name(i)));
1054 pending_feeds.erase(id);
1055 }
1056
1057 // Initialize the stack with the fetch nodes.
1058 std::vector<const Node*> stack;
1059 for (size_t i = 0; i < req.num_fetches(); ++i) {
1060 const string& fetch = req.fetch_name(i);
1061 const TensorId id(ParseTensorName(fetch));
1062 const Node* n = execution_state->get_node_by_name(string(id.first));
1063 if (n == nullptr) {
1064 return errors::NotFound("Fetch ", fetch, ": not found");
1065 }
1066 stack.push_back(n);
1067 }
1068
1069 // Any tensor needed for fetches can't be in pending_feeds.
1070 // We need to use the original full graph from execution state.
1071 const Graph* graph = execution_state->full_graph();
1072 std::vector<bool> visited(graph->num_node_ids(), false);
1073 while (!stack.empty()) {
1074 const Node* n = stack.back();
1075 stack.pop_back();
1076
1077 for (const Edge* in_edge : n->in_edges()) {
1078 const Node* in_node = in_edge->src();
1079 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1080 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1081 in_edge->src_output(),
1082 " can't be computed from the feeds"
1083 " that have been fed so far.");
1084 }
1085 if (!visited[in_node->id()]) {
1086 visited[in_node->id()] = true;
1087 stack.push_back(in_node);
1088 }
1089 }
1090 }
1091 return OkStatus();
1092}
1093
1094// Asynchronously deregisters subgraphs on the workers, without waiting for the
1095// result.
1096void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1097 struct Call {
1098 DeregisterGraphRequest req;
1099 DeregisterGraphResponse resp;
1100 };
1101 for (Part& part : partitions_) {
1102 // The graph handle may be empty if we failed during partition registration.
1103 if (!part.graph_handle.empty()) {
1104 Call* c = new Call;
1105 c->req.set_session_handle(session_handle_);
1106 c->req.set_create_worker_session_called(!should_deregister_);
1107 c->req.set_graph_handle(part.graph_handle);
1108 // NOTE(mrry): We must capture `worker_cache_` since `this`
1109 // could be deleted before the callback is called.
1110 WorkerCacheInterface* worker_cache = worker_cache_;
1111 const string name = part.name;
1112 WorkerInterface* w = part.worker;
1113 CHECK_NOTNULL(w);
1114 auto cb = [worker_cache, c, name, w](const Status& s) {
1115 if (!s.ok()) {
1116 // This error is potentially benign, so we don't log at the
1117 // error level.
1118 LOG(INFO) << "DeregisterGraph error: " << s;
1119 }
1120 delete c;
1121 worker_cache->ReleaseWorker(name, w);
1122 };
1123 w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1124 }
1125 }
1126}
1127
1128namespace {
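// Copies `size` strings produced by `input_accessor` into `output` and sorts
// them, so that the resulting CallableOptions (and hence the subgraph cache
// key) do not depend on the order of names in the request.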
1129void CopyAndSortStrings(size_t size,
1130 const std::function<string(size_t)>& input_accessor,
1131 protobuf::RepeatedPtrField<string>* output) {
1132 std::vector<string> temp;
1133 temp.reserve(size);
1134 for (size_t i = 0; i < size; ++i) {
1135 output->Add(input_accessor(i));
1136 }
1137 std::sort(output->begin(), output->end());
1138}
1139} // namespace
1140
1141void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1142 const ConfigProto& config,
1143 BuildGraphOptions* opts) {
1144 CallableOptions* callable_opts = &opts->callable_options;
1145 CopyAndSortStrings(
1146 req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1147 callable_opts->mutable_feed());
1148 CopyAndSortStrings(
1149 req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1150 callable_opts->mutable_fetch());
1151 CopyAndSortStrings(
1152 req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1153 callable_opts->mutable_target());
1154
1155 if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1156 *callable_opts->mutable_run_options()->mutable_debug_options() =
1157 req.options().debug_options();
1158 }
1159
1160 opts->collective_graph_key =
1161 req.options().experimental().collective_graph_key();
1162 if (config.experimental().collective_deterministic_sequential_execution()) {
1163 opts->collective_order = GraphCollectiveOrder::kEdges;
1164 } else if (config.experimental().collective_nccl()) {
1165 opts->collective_order = GraphCollectiveOrder::kAttrs;
1166 }
1167}
1168
1169void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1170 BuildGraphOptions* opts) {
1171 CallableOptions* callable_opts = &opts->callable_options;
1172 CopyAndSortStrings(
1173 req.feed_size(), [&req](size_t i) { return req.feed(i); },
1174 callable_opts->mutable_feed());
1175 CopyAndSortStrings(
1176 req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1177 callable_opts->mutable_fetch());
1178 CopyAndSortStrings(
1179 req.target_size(), [&req](size_t i) { return req.target(i); },
1180 callable_opts->mutable_target());
1181
1182 // TODO(cais): Add TFDBG support to partial runs.
1183}
1184
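// Computes a fingerprint of the feeds, targets, fetches and debug options in
// `opts`. MasterSession uses this hash as the cache key for compiled
// subgraphs (see StartStep).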
1185uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1186 uint64 h = 0x2b992ddfa23249d6ull;
1187 for (const string& name : opts.callable_options.feed()) {
1188 h = Hash64(name.c_str(), name.size(), h);
1189 }
1190 for (const string& name : opts.callable_options.target()) {
1191 h = Hash64(name.c_str(), name.size(), h);
1192 }
1193 for (const string& name : opts.callable_options.fetch()) {
1194 h = Hash64(name.c_str(), name.size(), h);
1195 }
1196
1197 const DebugOptions& debug_options =
1198 opts.callable_options.run_options().debug_options();
1199 if (!debug_options.debug_tensor_watch_opts().empty()) {
1200 const string watch_summary =
1201 SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1202 h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1203 }
1204
1205 return h;
1206}
1207
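// Renders `opts` as a short human-readable string for logging: feeds are
// prefixed with "FdE:", targets with "TN:", fetches with "FeE:", and the
// collective graph key (if any) with "GK:".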
1208string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1209 string buf;
1210 for (const string& name : opts.callable_options.feed()) {
1211 strings::StrAppend(&buf, " FdE: ", name);
1212 }
1213 strings::StrAppend(&buf, "\n");
1214 for (const string& name : opts.callable_options.target()) {
1215 strings::StrAppend(&buf, " TN: ", name);
1216 }
1217 strings::StrAppend(&buf, "\n");
1218 for (const string& name : opts.callable_options.fetch()) {
1219 strings::StrAppend(&buf, " FeE: ", name);
1220 }
1221 if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1222 strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1223 }
1224 strings::StrAppend(&buf, "\n");
1225 return buf;
1226}
1227
1228MasterSession::MasterSession(
1229 const SessionOptions& opt, const MasterEnv* env,
1230 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1231 std::unique_ptr<WorkerCacheInterface> worker_cache,
1232 std::unique_ptr<DeviceSet> device_set,
1233 std::vector<string> filtered_worker_list,
1234 StatsPublisherFactory stats_publisher_factory)
1235 : session_opts_(opt),
1236 env_(env),
1237 handle_(strings::FpToString(random::New64())),
1238 remote_devs_(std::move(remote_devs)),
1239 worker_cache_(std::move(worker_cache)),
1240 devices_(std::move(device_set)),
1241 filtered_worker_list_(std::move(filtered_worker_list)),
1242 stats_publisher_factory_(std::move(stats_publisher_factory)),
1243 graph_version_(0),
1244 run_graphs_(5),
1245 partial_run_graphs_(5) {
1246 UpdateLastAccessTime();
1247 CHECK(devices_) << "device_set was null!";
1248
1249 VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1250 << " #remote " << remote_devs_->size();
1251 VLOG(1) << "Start master session " << handle_
1252 << " with config: " << session_opts_.config.ShortDebugString();
1253}
1254
1255MasterSession::~MasterSession() {
1256 for (const auto& iter : run_graphs_) iter.second->Unref();
1257 for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1258}
1259
1260void MasterSession::UpdateLastAccessTime() {
1261 last_access_time_usec_.store(Env::Default()->NowMicros());
1262}
1263
1264Status MasterSession::Create(GraphDef&& graph_def,
1265 const ClusterDef& cluster_def) {
1266 if (session_opts_.config.use_per_session_threads() ||
1267 session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1268 return errors::InvalidArgument(
1269 "Distributed session does not support session thread pool options.");
1270 }
1271 if (session_opts_.config.graph_options().place_pruned_graph()) {
1272 // TODO(b/29900832): Fix this or remove the option.
1273 LOG(WARNING) << "Distributed session does not support the "
1274 "place_pruned_graph option.";
1275 session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1276 }
1277
1278 GraphExecutionStateOptions execution_options;
1279 execution_options.device_set = devices_.get();
1280 execution_options.session_options = &session_opts_;
1281 {
1282 mutex_lock l(mu_);
1283 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1284 std::move(graph_def), execution_options, &execution_state_));
1285 }
1286 should_delete_worker_sessions_ = true;
1287 return CreateWorkerSessions(cluster_def);
1288}
1289
1290Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) {
1291 const std::vector<string> worker_names = filtered_worker_list_;
1292 WorkerCacheInterface* worker_cache = get_worker_cache();
1293
1294 struct WorkerGroup {
1295 // The worker name. (Not owned.)
1296 const string* name;
1297
1298 // The worker referenced by name. (Not owned.)
1299 WorkerInterface* worker = nullptr;
1300
1301 // Request and responses used for a given worker.
1302 CreateWorkerSessionRequest request;
1303 CreateWorkerSessionResponse response;
1304 Status status = OkStatus();
1305 };
1306 BlockingCounter done(worker_names.size());
1307 std::vector<WorkerGroup> workers(worker_names.size());
1308
1309 // Release the workers.
1310 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1311 for (auto&& worker_group : workers) {
1312 if (worker_group.worker != nullptr) {
1313 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1314 }
1315 }
1316 });
1317
1318 string task_name;
1319 string local_device_name;
1320 DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
1321 &task_name, &local_device_name);
1322 const int64_t client_device_incarnation =
1323 devices_->client_device()->attributes().incarnation();
1324
1325 Status status = OkStatus();
1326 // Create all the workers & kick off the computations.
1327 for (size_t i = 0; i < worker_names.size(); ++i) {
1328 workers[i].name = &worker_names[i];
1329 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1330 workers[i].request.set_session_handle(handle_);
1331 workers[i].request.set_master_task(task_name);
1332 workers[i].request.set_master_incarnation(client_device_incarnation);
1333 if (session_opts_.config.share_cluster_devices_in_session() ||
1334 session_opts_.config.experimental()
1335 .share_cluster_devices_in_session()) {
1336 for (const auto& remote_dev : devices_->devices()) {
1337 *workers[i].request.add_cluster_device_attributes() =
1338 remote_dev->attributes();
1339 }
1340
1341 if (!session_opts_.config.share_cluster_devices_in_session() &&
1342 session_opts_.config.experimental()
1343 .share_cluster_devices_in_session()) {
1344 LOG(WARNING)
1345 << "ConfigProto.Experimental.share_cluster_devices_in_session has "
1346 "been promoted to a non-experimental API. Please use "
1347 "ConfigProto.share_cluster_devices_in_session instead. The "
1348 "experimental option will be removed in the future.";
1349 }
1350 }
1351
1352 DeviceNameUtils::ParsedName name;
1353 if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1354 status = errors::Internal("Could not parse name ", worker_names[i]);
1355 LOG(WARNING) << status;
1356 return status;
1357 }
1358 if (!name.has_job || !name.has_task) {
1359 status = errors::Internal("Incomplete worker name ", worker_names[i]);
1360 LOG(WARNING) << status;
1361 return status;
1362 }
1363
1364 workers[i].request.mutable_server_def()->set_protocol("grpc");
1365 workers[i].request.mutable_server_def()->set_job_name(name.job);
1366 workers[i].request.mutable_server_def()->set_task_index(name.task);
1367 if (!cluster_def.job().empty()) {
1368 *workers[i].request.mutable_server_def()->mutable_cluster() = cluster_def;
1369 // Session state is always isolated when ClusterSpec propagation
1370 // is in use.
1371 workers[i].request.set_isolate_session_state(true);
1372 } else {
1373 // NOTE(mrry): Do not set any component of the ServerDef,
1374 // because the worker will use its local configuration.
1375 workers[i].request.set_isolate_session_state(
1376 session_opts_.config.isolate_session_state());
1377 }
1378 CoordinationServiceConfig coordination_config;
1379 // Enable coordination service in session options by default if
1380 // unspecified in non-local targets.
1381 if (session_opts_.target != "local" &&
1382 !session_opts_.config.experimental().has_coordination_config()) {
1383 coordination_config.set_service_type("standalone");
1384 } else {
1385 coordination_config =
1386 session_opts_.config.experimental().coordination_config();
1387 }
1388 // Specify master task as coordination service leader.
1389 coordination_config.set_service_leader(task_name);
1390 *workers[i]
1391 .request.mutable_server_def()
1392 ->mutable_default_session_config()
1393 ->mutable_experimental()
1394 ->mutable_coordination_config() = coordination_config;
1395
1396 if (session_opts_.config.experimental()
1397 .share_session_state_in_clusterspec_propagation()) {
1398 // In a dynamic cluster, the ClusterSpec info is usually propagated by
1399 // master sessions. However, in data parallel training with multiple
      // masters ("between-graph replication"), we need to disable isolation
      // for different worker sessions to update the same variables in PS
      // tasks.
1403 workers[i].request.set_isolate_session_state(false);
1404 }
1405 }
1406
1407 for (size_t i = 0; i < worker_names.size(); ++i) {
1408 auto cb = [i, &workers, &done](const Status& s) {
1409 workers[i].status = s;
1410 done.DecrementCount();
1411 };
1412 workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1413 &workers[i].response, cb);
1414 }
1415
1416 done.Wait();
1417 for (size_t i = 0; i < workers.size(); ++i) {
1418 status.Update(workers[i].status);
1419 }
1420 return status;
1421}
1422
1423Status MasterSession::DeleteWorkerSessions() {
1424 WorkerCacheInterface* worker_cache = get_worker_cache();
1425 const std::vector<string>& worker_names = filtered_worker_list_;
1426
1427 struct WorkerGroup {
1428 // The worker name. (Not owned.)
1429 const string* name;
1430
1431 // The worker referenced by name. (Not owned.)
1432 WorkerInterface* worker = nullptr;
1433
1434 CallOptions call_opts;
1435
1436 // Request and responses used for a given worker.
1437 DeleteWorkerSessionRequest request;
1438 DeleteWorkerSessionResponse response;
1439 Status status = OkStatus();
1440 };
1441 BlockingCounter done(worker_names.size());
1442 std::vector<WorkerGroup> workers(worker_names.size());
1443
1444 // Release the workers.
1445 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1446 for (auto&& worker_group : workers) {
1447 if (worker_group.worker != nullptr) {
1448 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1449 }
1450 }
1451 });
1452
1453 Status status = OkStatus();
1454 // Create all the workers & kick off the computations.
1455 for (size_t i = 0; i < worker_names.size(); ++i) {
1456 workers[i].name = &worker_names[i];
1457 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1458 workers[i].request.set_session_handle(handle_);
1459 // Since the worker may have gone away, set a timeout to avoid blocking the
1460 // session-close operation.
1461 workers[i].call_opts.SetTimeout(10000);
1462 }
1463
1464 for (size_t i = 0; i < worker_names.size(); ++i) {
1465 auto cb = [i, &workers, &done](const Status& s) {
1466 workers[i].status = s;
1467 done.DecrementCount();
1468 };
1469 workers[i].worker->DeleteWorkerSessionAsync(
1470 &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1471 }
1472
1473 done.Wait();
1474 for (size_t i = 0; i < workers.size(); ++i) {
1475 status.Update(workers[i].status);
1476 }
1477 return status;
1478}
1479
1480Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1481 if (worker_cache_) {
1482 // This is a ClusterSpec-propagated session, and thus env_->local_devices
1483 // are invalid.
1484
1485 // Mark the "client_device" as the sole local device.
1486 const Device* client_device = devices_->client_device();
1487 for (const Device* dev : devices_->devices()) {
1488 if (dev != client_device) {
1489 *(resp->add_remote_device()) = dev->attributes();
1490 }
1491 }
1492 *(resp->add_local_device()) = client_device->attributes();
1493 } else {
1494 for (Device* dev : env_->local_devices) {
1495 *(resp->add_local_device()) = dev->attributes();
1496 }
1497 for (auto&& dev : *remote_devs_) {
1498 *(resp->add_local_device()) = dev->attributes();
1499 }
1500 }
1501 return OkStatus();
1502}
1503
1504Status MasterSession::Extend(const ExtendSessionRequest* req,
1505 ExtendSessionResponse* resp) {
1506 UpdateLastAccessTime();
1507 std::unique_ptr<GraphExecutionState> extended_execution_state;
1508 {
1509 mutex_lock l(mu_);
1510 if (closed_) {
1511 return errors::FailedPrecondition("Session is closed.");
1512 }
1513
1514 if (graph_version_ != req->current_graph_version()) {
1515 return errors::Aborted("Current version is ", graph_version_,
1516 " but caller expected ",
1517 req->current_graph_version(), ".");
1518 }
1519
1520 CHECK(execution_state_);
1521 TF_RETURN_IF_ERROR(
1522 execution_state_->Extend(req->graph_def(), &extended_execution_state));
1523
1524 CHECK(extended_execution_state);
1525 // The old execution state will be released outside the lock.
1526 execution_state_.swap(extended_execution_state);
1527 ++graph_version_;
1528 resp->set_new_graph_version(graph_version_);
1529 }
1530 return OkStatus();
1531}
1532
1533WorkerCacheInterface* MasterSession::get_worker_cache() const {
1534 if (worker_cache_) {
1535 return worker_cache_.get();
1536 }
1537 return env_->worker_cache;
1538}
1539
1540Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1541 ReffedClientGraph** out_rcg,
1542 int64_t* out_count) {
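  // Looks up (or builds and caches) the ReffedClientGraph matching `opts`.
  // On success, *out_rcg holds an extra reference that the caller is
  // responsible for releasing.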
1543 const uint64 hash = HashBuildGraphOptions(opts);
1544 {
1545 mutex_lock l(mu_);
1546 // TODO(suharshs): We cache partial run graphs and run graphs separately
1547 // because there is preprocessing that needs to only be run for partial
1548 // run calls.
1549 RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1550 auto iter = m->find(hash);
1551 if (iter == m->end()) {
1552 // We have not seen this subgraph before. Build the subgraph and
1553 // cache it.
1554 VLOG(1) << "Unseen hash " << hash << " for "
1555 << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1556 << "\n";
1557 std::unique_ptr<ClientGraph> client_graph;
1558 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1559 WorkerCacheInterface* worker_cache = get_worker_cache();
1560 auto entry = new ReffedClientGraph(
1561 handle_, opts, std::move(client_graph), session_opts_,
1562 stats_publisher_factory_, is_partial, worker_cache,
1563 !should_delete_worker_sessions_);
1564 iter = m->insert({hash, entry}).first;
1565 VLOG(1) << "Preparing to execute new graph";
1566 }
1567 *out_rcg = iter->second;
1568 (*out_rcg)->Ref();
1569 *out_count = (*out_rcg)->get_and_increment_execution_count();
1570 }
1571 return OkStatus();
1572}
1573
1574void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1575 RCGMap* rcg_map) {
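  // If `to_unref` is non-null, ownership of the graphs is handed to the
  // caller, who must Unref() them outside the session lock (see Close());
  // otherwise they are unreffed here.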
1576 VLOG(1) << "Discarding all reffed graphs";
1577 for (auto p : *rcg_map) {
1578 ReffedClientGraph* rcg = p.second;
1579 if (to_unref) {
1580 to_unref->push_back(rcg);
1581 } else {
1582 rcg->Unref();
1583 }
1584 }
1585 rcg_map->clear();
1586}
1587
1588uint64 MasterSession::NewStepId(int64_t graph_key) {
1589 if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1590 // StepId must leave the most-significant 7 bits empty for future use.
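    // Concretely, this is step_id & 0x01FFFFFFFFFFFFFF, i.e. bits 57..63 are
    // cleared.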
1591 return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1592 } else {
1593 uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1594 int32_t retry_count = 0;
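    // NextStepId() returns kInvalidId when no step-id sequence is available
    // yet for this graph_key, so refresh the sequence and retry, backing off
    // linearly (1s, 2s, ...) up to a cap of 60 seconds.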
1595 while (static_cast<int64_t>(step_id) == CollectiveExecutor::kInvalidId) {
1596 Notification note;
1597 Status status;
1598 env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1599 graph_key, [&status, &note](const Status& s) {
1600 status = s;
1601 note.Notify();
1602 });
1603 note.WaitForNotification();
1604 if (!status.ok()) {
1605 LOG(ERROR) << "Bad status from "
1606 "collective_executor_mgr->RefreshStepIdSequence: "
1607 << status << ". Retrying.";
1608 int64_t delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1609 Env::Default()->SleepForMicroseconds(delay_micros);
1610 } else {
1611 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1612 }
1613 }
1614 return step_id;
1615 }
1616}
1617
1618Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1619 PartialRunSetupResponse* resp) {
1620 std::vector<string> inputs, outputs, targets;
1621 for (const auto& feed : req->feed()) {
1622 inputs.push_back(feed);
1623 }
1624 for (const auto& fetch : req->fetch()) {
1625 outputs.push_back(fetch);
1626 }
1627 for (const auto& target : req->target()) {
1628 targets.push_back(target);
1629 }
1630
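  // Partial-run handles are a per-session monotonically increasing counter
  // rendered as a decimal string; the client echoes the handle back in each
  // subsequent RunStep call.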
1631 string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1632
1633 ReffedClientGraph* rcg = nullptr;
1634
1635 // Prepare.
1636 BuildGraphOptions opts;
1637 BuildBuildGraphOptions(*req, &opts);
1638 int64_t count = 0;
1639 TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1640
1641 rcg->Ref();
1642 RunState* run_state =
1643 new RunState(inputs, outputs, rcg,
1644 NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1645 {
1646 mutex_lock l(mu_);
1647 partial_runs_.emplace(
1648 std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1649 }
1650
1651 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1652
1653 resp->set_partial_run_handle(handle);
1654 return OkStatus();
1655}
1656
1657Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1658 MutableRunStepResponseWrapper* resp) {
1659 UpdateLastAccessTime();
1660 {
1661 mutex_lock l(mu_);
1662 if (closed_) {
1663 return errors::FailedPrecondition("Session is closed.");
1664 }
1665 ++num_running_;
1666 // Note: all code paths must eventually call MarkRunCompletion()
1667    // in order to appropriately decrement the num_running_ counter.
1668 }
1669 Status status;
1670 if (!req.partial_run_handle().empty()) {
1671 status = DoPartialRun(opts, req, resp);
1672 } else {
1673 status = DoRunWithLocalExecution(opts, req, resp);
1674 }
1675 return status;
1676}
1677
1678// Decrements num_running_ and broadcasts if num_running_ is zero.
1679void MasterSession::MarkRunCompletion() {
1680 mutex_lock l(mu_);
1681 --num_running_;
1682 if (num_running_ == 0) {
1683 num_running_is_zero_.notify_all();
1684 }
1685}
1686
1687Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1688  // Registers subgraphs if we haven't done so already.
1689 PartitionOptions popts;
1690 popts.node_to_loc = SplitByWorker;
1691  // The closures popts.{new_name,get_incarnation} are called synchronously in
1692  // RegisterPartitions() below, so they do not need a Ref()/Unref() pair to
1693  // keep "this" alive during the closure.
1694 popts.new_name = [this](const string& prefix) {
1695 mutex_lock l(mu_);
1696 return strings::StrCat(prefix, "_S", next_node_id_++);
1697 };
1698 popts.get_incarnation = [this](const string& name) -> int64 {
1699 Device* d = devices_->FindDeviceByName(name);
1700 if (d == nullptr) {
1701 return PartitionOptions::kIllegalIncarnation;
1702 } else {
1703 return d->attributes().incarnation();
1704 }
1705 };
1706 popts.control_flow_added = false;
1707 const bool enable_bfloat16_sendrecv =
1708 session_opts_.config.graph_options().enable_bfloat16_sendrecv();
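  // When bfloat16 send/recv is enabled, float tensors crossing partition
  // boundaries are cast to bfloat16 for the Send/Recv pair (and cast back to
  // float on the receiving side), roughly halving the bytes on the wire.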
1709 popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1710 if (e->IsControlEdge()) {
1711 return DT_FLOAT;
1712 }
1713 DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1714 if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1715 return DT_BFLOAT16;
1716 } else {
1717 return dtype;
1718 }
1719 };
1720 if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1721 popts.scheduling_for_recvs = true;
1722 popts.need_to_record_start_times = true;
1723 }
1724
1725 TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1726
1727 return OkStatus();
1728}
1729
1730Status MasterSession::DoPartialRun(CallOptions* opts,
1731 const RunStepRequestWrapper& req,
1732 MutableRunStepResponseWrapper* resp) {
1733 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1734 const string& prun_handle = req.partial_run_handle();
1735 RunState* run_state = nullptr;
1736 {
1737 mutex_lock l(mu_);
1738 auto it = partial_runs_.find(prun_handle);
1739 if (it == partial_runs_.end()) {
1740 return errors::InvalidArgument(
1741 "Must run PartialRunSetup before performing partial runs");
1742 }
1743 run_state = it->second.get();
1744 }
1745 // CollectiveOps are not supported in partial runs.
1746 if (req.options().experimental().collective_graph_key() !=
1747 BuildGraphOptions::kNoCollectiveGraphKey) {
1748 return errors::InvalidArgument(
1749 "PartialRun does not support Collective ops. collective_graph_key "
1750 "must be kNoCollectiveGraphKey.");
1751 }
1752
1753 // If this is the first partial run, initialize the PerStepState.
1754 if (!run_state->step_started) {
1755 run_state->step_started = true;
1756 PerStepState pss;
1757
1758 const auto count = run_state->count;
1759 pss.collect_timeline =
1760 req.options().trace_level() == RunOptions::FULL_TRACE;
1761 pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1762 pss.report_tensor_allocations_upon_oom =
1763 req.options().report_tensor_allocations_upon_oom();
1764
1765    // Build the cost model every 'build_cost_model_every' steps after skipping
1766    // an initial 'build_cost_model_after' steps.
1768 const int64_t build_cost_model_after =
1769 session_opts_.config.graph_options().build_cost_model_after();
1770 const int64_t build_cost_model_every =
1771 session_opts_.config.graph_options().build_cost_model();
1772 pss.collect_costs =
1773 build_cost_model_every > 0 &&
1774 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1775 pss.collect_partition_graphs = req.options().output_partition_graphs();
1776
1777 std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1778 run_state->step_id, count, req.options());
1779 if (ph) {
1780 pss.collect_timeline = true;
1781 pss.collect_rpcs = ph->should_collect_rpcs();
1782 }
1783
1784 run_state->pss = std::move(pss);
1785 run_state->ph = std::move(ph);
1786 }
1787
1788 // Make sure that this is a new set of feeds that are still pending.
1789 for (size_t i = 0; i < req.num_feeds(); ++i) {
1790 const string& feed = req.feed_name(i);
1791 auto it = run_state->pending_inputs.find(feed);
1792 if (it == run_state->pending_inputs.end()) {
1793 return errors::InvalidArgument(
1794 "The feed ", feed, " was not specified in partial_run_setup.");
1795 } else if (it->second) {
1796 return errors::InvalidArgument("The feed ", feed,
1797 " has already been fed.");
1798 }
1799 }
1800 // Check that this is a new set of fetches that are still pending.
1801 for (size_t i = 0; i < req.num_fetches(); ++i) {
1802 const string& fetch = req.fetch_name(i);
1803 auto it = run_state->pending_outputs.find(fetch);
1804 if (it == run_state->pending_outputs.end()) {
1805 return errors::InvalidArgument(
1806 "The fetch ", fetch, " was not specified in partial_run_setup.");
1807 } else if (it->second) {
1808 return errors::InvalidArgument("The fetch ", fetch,
1809 " has already been fetched.");
1810 }
1811 }
1812
1813 // Ensure that the requested fetches can be computed from the provided feeds.
1814 {
1815 mutex_lock l(mu_);
1816 TF_RETURN_IF_ERROR(
1817 run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1818 }
1819
1820 // Determine if this partial run satisfies all the pending inputs and outputs.
1821 for (size_t i = 0; i < req.num_feeds(); ++i) {
1822 auto it = run_state->pending_inputs.find(req.feed_name(i));
1823 it->second = true;
1824 }
1825 for (size_t i = 0; i < req.num_fetches(); ++i) {
1826 auto it = run_state->pending_outputs.find(req.fetch_name(i));
1827 it->second = true;
1828 }
1829 bool is_last_partial_run = run_state->PendingDone();
1830
1831 Status s = run_state->rcg->RunPartitions(
1832 env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1833 resp, &cancellation_manager_, is_last_partial_run);
1834
1835 // Delete the run state if there is an error or all fetches are done.
1836 if (!s.ok() || is_last_partial_run) {
1837 ReffedClientGraph* rcg = run_state->rcg;
1838 run_state->pss.end_micros = Env::Default()->NowMicros();
1839 // Schedule post-processing and cleanup to be done asynchronously.
1840 Ref();
1841 rcg->Ref();
1842 rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1843 req.options(), resp->mutable_metadata());
1844 cleanup.release(); // MarkRunCompletion called in done closure.
1845 rcg->CleanupPartitionsAsync(
1846 run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1847 if (!s.ok()) {
1848 LOG(ERROR) << "Cleanup partition error: " << s;
1849 }
1850 rcg->Unref();
1851 MarkRunCompletion();
1852 Unref();
1853 });
1854 mutex_lock l(mu_);
1855 partial_runs_.erase(prun_handle);
1856 }
1857 return s;
1858}
1859
1860Status MasterSession::CreateDebuggerState(
1861 const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1862 int64_t rcg_execution_count,
1863 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1864 TF_RETURN_IF_ERROR(
1865 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1866
1867 std::vector<string> input_names;
1868 for (size_t i = 0; i < req.num_feeds(); ++i) {
1869 input_names.push_back(req.feed_name(i));
1870 }
1871 std::vector<string> output_names;
1872 for (size_t i = 0; i < req.num_fetches(); ++i) {
1873 output_names.push_back(req.fetch_name(i));
1874 }
1875 std::vector<string> target_names;
1876 for (size_t i = 0; i < req.num_targets(); ++i) {
1877 target_names.push_back(req.target_name(i));
1878 }
1879
1880 // TODO(cais): We currently use -1 as a dummy value for session run count.
1881 // While this counter value is straightforward to define and obtain for
1882 // DirectSessions, it is less so for non-direct Sessions. Devise a better
1883 // way to get its value when the need arises.
1884 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1885 debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1886 input_names, output_names, target_names));
1887
1888 return OkStatus();
1889}
1890
1891void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1892 const RunOptions& run_options,
1893 uint64 step_id, int64_t count,
1894 PerStepState* out_pss,
1895 std::unique_ptr<ProfileHandler>* out_ph) {
1896 out_pss->collect_timeline =
1897 run_options.trace_level() == RunOptions::FULL_TRACE;
1898 out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1899 out_pss->report_tensor_allocations_upon_oom =
1900 run_options.report_tensor_allocations_upon_oom();
1901 // Build the cost model every 'build_cost_model_every' steps after skipping an
1902 // initial 'build_cost_model_after' steps.
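  // For example, with build_cost_model_after = 10 and build_cost_model_every
  // = 5, costs are collected whenever (count + 1 - 10) is a multiple of 5,
  // e.g. at execution counts 9, 14, 19, ...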
1903 const int64_t build_cost_model_after =
1904 session_opts_.config.graph_options().build_cost_model_after();
1905 const int64_t build_cost_model_every =
1906 session_opts_.config.graph_options().build_cost_model();
1907 out_pss->collect_costs =
1908 build_cost_model_every > 0 &&
1909 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1910 out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1911
1912 *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1913 if (*out_ph) {
1914 out_pss->collect_timeline = true;
1915 out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1916 }
1917}
1918
1919Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1920 uint64 step_id,
1921 const RunOptions& run_options,
1922 PerStepState* pss,
1923 const std::unique_ptr<ProfileHandler>& ph,
1924 const Status& run_status,
1925 RunMetadata* out_run_metadata) {
1926 Status s = run_status;
1927 if (s.ok()) {
1928 pss->end_micros = Env::Default()->NowMicros();
1929 if (rcg->collective_graph_key() !=
1930 BuildGraphOptions::kNoCollectiveGraphKey) {
1931 env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1932 step_id);
1933 }
1934 // Schedule post-processing and cleanup to be done asynchronously.
1935 rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1936 } else if (errors::IsCancelled(s)) {
1937 mutex_lock l(mu_);
1938 if (closed_) {
1939 if (garbage_collected_) {
1940 s = errors::Cancelled(
1941 "Step was cancelled because the session was garbage collected due "
1942 "to inactivity.");
1943 } else {
1944 s = errors::Cancelled(
1945 "Step was cancelled by an explicit call to `Session::Close()`.");
1946 }
1947 }
1948 }
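  // Take extra references on the session and the graph so both outlive the
  // asynchronous partition cleanup; they are released in the callback below.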
1949 Ref();
1950 rcg->Ref();
1951 rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1952 if (!s.ok()) {
1953 LOG(ERROR) << "Cleanup partition error: " << s;
1954 }
1955 rcg->Unref();
1956 MarkRunCompletion();
1957 Unref();
1958 });
1959 return s;
1960}
1961
1962Status MasterSession::DoRunWithLocalExecution(
1963 CallOptions* opts, const RunStepRequestWrapper& req,
1964 MutableRunStepResponseWrapper* resp) {
1965 VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1966 PerStepState pss;
1967 pss.start_micros = Env::Default()->NowMicros();
1968 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1969
1970 // Prepare.
1971 BuildGraphOptions bgopts;
1972 BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1973 ReffedClientGraph* rcg = nullptr;
1974 int64_t count;
1975 TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1976
1977 // Unref "rcg" when out of scope.
1978 core::ScopedUnref unref(rcg);
1979
1980 std::unique_ptr<DebuggerStateInterface> debugger_state;
1981 const DebugOptions& debug_options = req.options().debug_options();
1982
1983 if (!debug_options.debug_tensor_watch_opts().empty()) {
1984 TF_RETURN_IF_ERROR(
1985 CreateDebuggerState(debug_options, req, count, &debugger_state));
1986 }
1987 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1988
1989  // NewStepId() reserves the most-significant bits of the step_id for future
1990  // use (see the comment in NewStepId() above).
1991 uint64 step_id = NewStepId(rcg->collective_graph_key());
1992 TRACEPRINTF("stepid %llu", step_id);
1993
1994 std::unique_ptr<ProfileHandler> ph;
1995 FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1996
1997 if (pss.collect_partition_graphs &&
1998 session_opts_.config.experimental().disable_output_partition_graphs()) {
1999 return errors::InvalidArgument(
2000 "RunOptions.output_partition_graphs() is not supported when "
2001 "disable_output_partition_graphs is true.");
2002 }
2003
2004 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2005 &cancellation_manager_, false);
2006
2007 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
2008 return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
2009 resp->mutable_metadata());
2010}
2011
2012Status MasterSession::MakeCallable(const MakeCallableRequest& req,
2013 MakeCallableResponse* resp) {
2014 UpdateLastAccessTime();
2015
2016 BuildGraphOptions opts;
2017 opts.callable_options = req.options();
2018 opts.use_function_convention = false;
2019
2020 ReffedClientGraph* callable;
2021
2022 {
2023 mutex_lock l(mu_);
2024 if (closed_) {
2025 return errors::FailedPrecondition("Session is closed.");
2026 }
2027 std::unique_ptr<ClientGraph> client_graph;
2028 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
2029 callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
2030 session_opts_, stats_publisher_factory_,
2031 false /* is_partial */, get_worker_cache(),
2032 !should_delete_worker_sessions_);
2033 }
2034
2035 Status s = BuildAndRegisterPartitions(callable);
2036 if (!s.ok()) {
2037 callable->Unref();
2038 return s;
2039 }
2040
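  // Callable handles are allocated from a monotonically increasing counter
  // and never reused, which lets RunCallable() distinguish an unknown handle
  // from one that has already been released.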
2041 uint64 handle;
2042 {
2043 mutex_lock l(mu_);
2044 handle = next_callable_handle_++;
2045 callables_[handle] = callable;
2046 }
2047
2048 resp->set_handle(handle);
2049 return OkStatus();
2050}
2051
2052Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
2053 const RunCallableRequest& req,
2054 RunCallableResponse* resp) {
2055 VLOG(2) << "DoRunCallable req: " << req.DebugString();
2056 PerStepState pss;
2057 pss.start_micros = Env::Default()->NowMicros();
2058 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
2059
2060 // Prepare.
2061 int64_t count = rcg->get_and_increment_execution_count();
2062
2063 const uint64 step_id = NewStepId(rcg->collective_graph_key());
2064 TRACEPRINTF("stepid %llu", step_id);
2065
2066 const RunOptions& run_options = rcg->callable_options().run_options();
2067
2068 if (run_options.timeout_in_ms() != 0) {
2069 opts->SetTimeout(run_options.timeout_in_ms());
2070 }
2071
2072 std::unique_ptr<ProfileHandler> ph;
2073 FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
2074 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2075 &cancellation_manager_);
2076 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
2077 return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
2078 resp->mutable_metadata());
2079}
2080
2081Status MasterSession::RunCallable(CallOptions* opts,
2082 const RunCallableRequest& req,
2083 RunCallableResponse* resp) {
2084 UpdateLastAccessTime();
2085 ReffedClientGraph* callable;
2086 {
2087 mutex_lock l(mu_);
2088 if (closed_) {
2089 return errors::FailedPrecondition("Session is closed.");
2090 }
2091 int64_t handle = req.handle();
2092 if (handle >= next_callable_handle_) {
2093 return errors::InvalidArgument("No such callable handle: ", handle);
2094 }
2095 auto iter = callables_.find(req.handle());
2096 if (iter == callables_.end()) {
2097 return errors::InvalidArgument(
2098 "Attempted to run callable after handle was released: ", handle);
2099 }
2100 callable = iter->second;
2101 callable->Ref();
2102 ++num_running_;
2103 }
2104 core::ScopedUnref unref_callable(callable);
2105 return DoRunCallable(opts, callable, req, resp);
2106}
2107
2108Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
2109 ReleaseCallableResponse* resp) {
2110 UpdateLastAccessTime();
2111 ReffedClientGraph* to_unref = nullptr;
2112 {
2113 mutex_lock l(mu_);
2114 auto iter = callables_.find(req.handle());
2115 if (iter != callables_.end()) {
2116 to_unref = iter->second;
2117 callables_.erase(iter);
2118 }
2119 }
2120 if (to_unref != nullptr) {
2121 to_unref->Unref();
2122 }
2123 return OkStatus();
2124}
2125
2126Status MasterSession::Close() {
2127 {
2128 mutex_lock l(mu_);
2129 closed_ = true; // All subsequent calls to Run() or Extend() will fail.
2130 }
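  // Cancel any in-flight steps outside the lock, then re-acquire it and wait
  // for them to drain before discarding the cached graphs.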
2131 cancellation_manager_.StartCancel();
2132 std::vector<ReffedClientGraph*> to_unref;
2133 {
2134 mutex_lock l(mu_);
2135 while (num_running_ != 0) {
2136 num_running_is_zero_.wait(l);
2137 }
2138 ClearRunsTable(&to_unref, &run_graphs_);
2139 ClearRunsTable(&to_unref, &partial_run_graphs_);
2140 ClearRunsTable(&to_unref, &callables_);
2141 }
2142 for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2143 if (should_delete_worker_sessions_) {
2144 Status s = DeleteWorkerSessions();
2145 if (!s.ok()) {
2146 LOG(WARNING) << s;
2147 }
2148 }
2149 return OkStatus();
2150}
2151
2152void MasterSession::GarbageCollect() {
2153 {
2154 mutex_lock l(mu_);
2155 closed_ = true;
2156 garbage_collected_ = true;
2157 }
2158 cancellation_manager_.StartCancel();
2159 Unref();
2160}
2161
2162MasterSession::RunState::RunState(const std::vector<string>& input_names,
2163 const std::vector<string>& output_names,
2164 ReffedClientGraph* rcg, const uint64 step_id,
2165 const int64_t count)
2166 : rcg(rcg), step_id(step_id), count(count) {
2167 // Initially all the feeds and fetches are pending.
2168 for (auto& name : input_names) {
2169 pending_inputs[name] = false;
2170 }
2171 for (auto& name : output_names) {
2172 pending_outputs[name] = false;
2173 }
2174}
2175
2176MasterSession::RunState::~RunState() {
2177 if (rcg) rcg->Unref();
2178}
2179
2180bool MasterSession::RunState::PendingDone() const {
2181 for (const auto& it : pending_inputs) {
2182 if (!it.second) return false;
2183 }
2184 for (const auto& it : pending_outputs) {
2185 if (!it.second) return false;
2186 }
2187 return true;
2188}
2189
2190} // end namespace tensorflow
2191