1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/master_session.h" |
17 | |
18 | #include <algorithm> |
19 | #include <functional> |
20 | #include <memory> |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <unordered_set> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "tensorflow/core/common_runtime/process_util.h" |
28 | #include "tensorflow/core/common_runtime/profile_handler.h" |
29 | #include "tensorflow/core/common_runtime/stats_publisher_interface.h" |
30 | #include "tensorflow/core/debug/debug_graph_utils.h" |
31 | #include "tensorflow/core/distributed_runtime/request_id.h" |
32 | #include "tensorflow/core/distributed_runtime/scheduler.h" |
33 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
34 | #include "tensorflow/core/distributed_runtime/worker_interface.h" |
35 | #include "tensorflow/core/framework/allocation_description.pb.h" |
36 | #include "tensorflow/core/framework/collective.h" |
37 | #include "tensorflow/core/framework/cost_graph.pb.h" |
38 | #include "tensorflow/core/framework/graph_def_util.h" |
39 | #include "tensorflow/core/framework/node_def.pb.h" |
40 | #include "tensorflow/core/framework/node_def_util.h" |
41 | #include "tensorflow/core/framework/tensor.h" |
42 | #include "tensorflow/core/framework/tensor.pb.h" |
43 | #include "tensorflow/core/framework/tensor_description.pb.h" |
44 | #include "tensorflow/core/graph/graph_partition.h" |
45 | #include "tensorflow/core/graph/tensor_id.h" |
46 | #include "tensorflow/core/lib/core/notification.h" |
47 | #include "tensorflow/core/lib/core/refcount.h" |
48 | #include "tensorflow/core/lib/core/status.h" |
49 | #include "tensorflow/core/lib/gtl/cleanup.h" |
50 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
51 | #include "tensorflow/core/lib/gtl/map_util.h" |
52 | #include "tensorflow/core/lib/random/random.h" |
53 | #include "tensorflow/core/lib/strings/numbers.h" |
54 | #include "tensorflow/core/lib/strings/str_util.h" |
55 | #include "tensorflow/core/lib/strings/strcat.h" |
56 | #include "tensorflow/core/lib/strings/stringprintf.h" |
57 | #include "tensorflow/core/platform/blocking_counter.h" |
58 | #include "tensorflow/core/platform/env.h" |
59 | #include "tensorflow/core/platform/logging.h" |
60 | #include "tensorflow/core/platform/macros.h" |
61 | #include "tensorflow/core/platform/mutex.h" |
62 | #include "tensorflow/core/platform/tracing.h" |
63 | #include "tensorflow/core/protobuf/config.pb.h" |
64 | #include "tensorflow/core/protobuf/coordination_config.pb.h" |
65 | #include "tensorflow/core/public/session_options.h" |
66 | #include "tensorflow/core/util/device_name_utils.h" |
67 | |
68 | namespace tensorflow { |
69 | |
70 | // MasterSession wraps ClientGraph in a reference counted object. |
71 | // This way, MasterSession can clear up the cache mapping Run requests to |
72 | // compiled graphs while the compiled graph is still being used. |
73 | // |
74 | // TODO(zhifengc): Cleanup this class. It's becoming messy. |
75 | class MasterSession::ReffedClientGraph : public core::RefCounted { |
76 | public: |
77 | ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts, |
78 | std::unique_ptr<ClientGraph> client_graph, |
79 | const SessionOptions& session_opts, |
80 | const StatsPublisherFactory& stats_publisher_factory, |
81 | bool is_partial, WorkerCacheInterface* worker_cache, |
82 | bool should_deregister) |
83 | : session_handle_(handle), |
84 | bg_opts_(bopts), |
85 | client_graph_before_register_(std::move(client_graph)), |
86 | session_opts_(session_opts), |
87 | is_partial_(is_partial), |
88 | callable_opts_(bopts.callable_options), |
89 | worker_cache_(worker_cache), |
90 | should_deregister_(should_deregister), |
91 | collective_graph_key_( |
92 | client_graph_before_register_->collective_graph_key) { |
    VLOG(1) << "Created ReffedClientGraph with "
            << client_graph_before_register_->graph.num_node_ids()
            << " node ids";
95 | |
96 | stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts); |
97 | |
    // Initialize a name-to-node-details map for processing device stats.
99 | for (Node* n : client_graph_before_register_->graph.nodes()) { |
100 | name_to_node_details_.emplace( |
101 | n->name(), |
102 | NodeDetails(n->type_string(), |
                      strings::StrCat(
                          "(", absl::StrJoin(n->requested_inputs(), ", "))));
105 | } |
106 | } |
107 | |
108 | ~ReffedClientGraph() override { |
109 | if (should_deregister_) { |
110 | DeregisterPartitions(); |
111 | } else { |
112 | for (Part& part : partitions_) { |
113 | worker_cache_->ReleaseWorker(part.name, part.worker); |
114 | } |
115 | } |
116 | } |
117 | |
118 | const CallableOptions& callable_options() { return callable_opts_; } |
119 | |
120 | const BuildGraphOptions& build_graph_options() { return bg_opts_; } |
121 | |
122 | int64_t collective_graph_key() { return collective_graph_key_; } |
123 | |
124 | std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step, |
125 | int64_t execution_count, |
126 | const RunOptions& ropts) { |
127 | return stats_publisher_->GetProfileHandler(step, execution_count, ropts); |
128 | } |
129 | |
130 | int64_t get_and_increment_execution_count() { |
131 | return execution_count_.fetch_add(1); |
132 | } |
133 | |
134 | // Turn RPC logging on or off, both at the WorkerCache used by this |
135 | // master process, and at each remote worker in use for the current |
136 | // partitions. |
137 | void SetRPCLogging(bool active) { |
138 | worker_cache_->SetLogging(active); |
139 | // Logging is a best-effort activity, so we make async calls to turn |
140 | // it on/off and don't make use of the responses. |
141 | for (auto& p : partitions_) { |
142 | LoggingRequest* req = new LoggingRequest; |
143 | if (active) { |
144 | req->set_enable_rpc_logging(true); |
145 | } else { |
146 | req->set_disable_rpc_logging(true); |
147 | } |
148 | LoggingResponse* resp = new LoggingResponse; |
149 | Ref(); |
150 | p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) { |
151 | delete req; |
152 | delete resp; |
        // ReffedClientGraph owns p.worker, so we need to hold a ref to
        // ensure that this callback doesn't attempt to access p.worker after
        // the ReffedClientGraph has deleted it.
156 | // TODO(suharshs): Simplify this ownership model. |
157 | Unref(); |
158 | }); |
159 | } |
160 | } |
161 | |
162 | // Retrieve all RPC logs data accumulated for the current step, both |
163 | // from the local WorkerCache in use by this master process and from |
164 | // all the remote workers executing the remote partitions. |
165 | void RetrieveLogs(int64_t step_id, StepStats* ss) { |
166 | // Get the local data first, because it sets *ss without merging. |
167 | worker_cache_->RetrieveLogs(step_id, ss); |
168 | |
169 | // Then merge in data from all the remote workers. |
170 | LoggingRequest req; |
171 | req.add_fetch_step_id(step_id); |
172 | int waiting_for = partitions_.size(); |
173 | if (waiting_for > 0) { |
174 | mutex scoped_mu; |
175 | BlockingCounter all_done(waiting_for); |
176 | for (auto& p : partitions_) { |
177 | LoggingResponse* resp = new LoggingResponse; |
178 | p.worker->LoggingAsync( |
179 | &req, resp, |
180 | [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) { |
181 | { |
182 | mutex_lock l(scoped_mu); |
183 | if (s.ok()) { |
184 | for (auto& lss : resp->step()) { |
185 | if (step_id != lss.step_id()) { |
                      LOG(ERROR) << "Wrong step_id in LoggingResponse";
187 | continue; |
188 | } |
189 | ss->MergeFrom(lss.step_stats()); |
190 | } |
191 | } |
192 | delete resp; |
193 | } |
194 | // Must not decrement all_done until out of critical section where |
195 | // *ss is updated. |
196 | all_done.DecrementCount(); |
197 | }); |
198 | } |
199 | all_done.Wait(); |
200 | } |
201 | } |
202 | |
203 | // Local execution methods. |
204 | |
205 | // Partitions the graph into subgraphs and registers them on |
206 | // workers. |
207 | Status RegisterPartitions(PartitionOptions popts); |
208 | |
209 | // Runs one step of all partitions. |
210 | Status RunPartitions(const MasterEnv* env, int64_t step_id, |
211 | int64_t execution_count, PerStepState* pss, |
212 | CallOptions* opts, const RunStepRequestWrapper& req, |
213 | MutableRunStepResponseWrapper* resp, |
214 | CancellationManager* cm, const bool is_last_partial_run); |
215 | Status RunPartitions(const MasterEnv* env, int64_t step_id, |
216 | int64_t execution_count, PerStepState* pss, |
217 | CallOptions* call_opts, const RunCallableRequest& req, |
218 | RunCallableResponse* resp, CancellationManager* cm); |
219 | |
220 | // Calls workers to cleanup states for the step "step_id". Calls |
221 | // `done` when all cleanup RPCs have completed. |
222 | void CleanupPartitionsAsync(int64_t step_id, StatusCallback done); |
223 | |
224 | // Post-processing of any runtime statistics gathered during execution. |
225 | void ProcessStats(int64_t step_id, PerStepState* pss, ProfileHandler* ph, |
226 | const RunOptions& options, RunMetadata* resp); |
227 | void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds, |
228 | bool is_rpc); |
229 | // Checks that the requested fetches can be computed from the provided feeds. |
230 | Status CheckFetches(const RunStepRequestWrapper& req, |
231 | const RunState* run_state, |
232 | GraphExecutionState* execution_state); |
233 | |
234 | private: |
235 | const string session_handle_; |
236 | const BuildGraphOptions bg_opts_; |
237 | |
238 | // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns. |
239 | std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_); |
240 | const SessionOptions session_opts_; |
241 | const bool is_partial_; |
242 | const CallableOptions callable_opts_; |
243 | WorkerCacheInterface* const worker_cache_; // Not owned. |
244 | |
245 | struct NodeDetails { |
246 | explicit NodeDetails(string type_string, string detail_text) |
247 | : type_string(std::move(type_string)), |
248 | detail_text(std::move(detail_text)) {} |
249 | const string type_string; |
250 | const string detail_text; |
251 | }; |
252 | std::unordered_map<string, NodeDetails> name_to_node_details_; |
253 | |
254 | const bool should_deregister_; |
255 | const int64_t collective_graph_key_; |
256 | std::atomic<int64_t> execution_count_ = {0}; |
257 | |
258 | // Graph partitioned into per-location subgraphs. |
259 | struct Part { |
260 | // Worker name. |
261 | string name; |
262 | |
263 | // Maps feed names to rendezvous keys. Empty most of the time. |
264 | std::unordered_map<string, string> feed_key; |
265 | |
266 | // Maps rendezvous keys to fetch names. Empty most of the time. |
267 | std::unordered_map<string, string> key_fetch; |
268 | |
269 | // The interface to the worker. Owned. |
270 | WorkerInterface* worker = nullptr; |
271 | |
272 | // After registration with the worker, graph_handle identifies |
273 | // this partition on the worker. |
274 | string graph_handle; |
275 | |
276 | Part() : feed_key(3), key_fetch(3) {} |
277 | }; |
278 | |
279 | // partitions_ is immutable after RegisterPartitions() call |
280 | // finishes. RunPartitions() can access partitions_ safely without |
281 | // acquiring locks. |
282 | std::vector<Part> partitions_; |
283 | |
284 | mutable mutex mu_; |
285 | |
286 | // Partition initialization and registration only needs to happen |
287 | // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()` |
288 | // indicates the initialization is ongoing. |
289 | Notification init_done_; |
290 | |
291 | // init_result_ remembers the initialization error if any. |
292 | Status init_result_ TF_GUARDED_BY(mu_); |
293 | |
294 | std::unique_ptr<StatsPublisherInterface> stats_publisher_; |
295 | |
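  // Formats a one-line description of a node for profiling output: the total
  // requested output bytes (when at least ~0.1MB), the node name, its op
  // type, and the detail text recorded at construction time.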
296 | string DetailText(const NodeDetails& details, const NodeExecStats& stats) { |
297 | int64_t tot = 0; |
298 | for (auto& no : stats.output()) { |
299 | tot += no.tensor_description().allocation_description().requested_bytes(); |
300 | } |
301 | string bytes; |
302 | if (tot >= 0.1 * 1048576.0) { |
      bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
304 | } |
    return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
306 | details.detail_text); |
307 | } |
308 | |
309 | // Send/Recv nodes that are the result of client-added |
310 | // feeds and fetches must be tracked so that the tensors |
311 | // can be added to the local rendezvous. |
312 | static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def, |
313 | const PartitionOptions& popts); |
314 | |
315 | // The actual graph partitioning and registration implementation. |
316 | Status DoBuildPartitions( |
317 | PartitionOptions popts, ClientGraph* client_graph, |
318 | std::unordered_map<string, GraphDef>* out_partitions); |
319 | Status DoRegisterPartitions( |
320 | const PartitionOptions& popts, |
321 | std::unordered_map<string, GraphDef> graph_partitions); |
322 | |
323 | // Prepares a number of calls to workers. One call per partition. |
324 | // This is a generic method that handles Run, PartialRun, and RunCallable. |
325 | template <class FetchListType, class ClientRequestType, |
326 | class ClientResponseType> |
327 | Status RunPartitionsHelper( |
328 | const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds, |
329 | const FetchListType& fetches, const MasterEnv* env, int64_t step_id, |
330 | int64_t execution_count, PerStepState* pss, CallOptions* call_opts, |
331 | const ClientRequestType& req, ClientResponseType* resp, |
332 | CancellationManager* cm, bool is_last_partial_run); |
333 | |
  // Deregisters the partitions on the workers. Called from the
  // destructor and does not wait for the RPCs to complete.
336 | void DeregisterPartitions(); |
337 | |
338 | TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph); |
339 | }; |
340 | |
341 | Status MasterSession::ReffedClientGraph::RegisterPartitions( |
342 | PartitionOptions popts) { |
343 | { // Ensure register once. |
344 | mu_.lock(); |
345 | if (client_graph_before_register_) { |
      // The `ClientGraph` is no longer needed after partitions are registered.
      // Since it can account for a large amount of memory, we consume it here,
      // and it will be freed once registration concludes.
349 | |
350 | std::unique_ptr<ClientGraph> client_graph; |
351 | std::swap(client_graph_before_register_, client_graph); |
352 | mu_.unlock(); |
353 | std::unordered_map<string, GraphDef> graph_defs; |
354 | popts.flib_def = client_graph->flib_def.get(); |
355 | Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs); |
356 | if (s.ok()) { |
357 | // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain |
358 | // valid after the call to DoRegisterPartitions begins, so |
359 | // `stats_publisher_` must make a copy if it wants to retain the |
360 | // GraphDef objects. |
361 | std::vector<const GraphDef*> graph_defs_for_publishing; |
        graph_defs_for_publishing.reserve(graph_defs.size());
363 | for (const auto& name_def : graph_defs) { |
364 | graph_defs_for_publishing.push_back(&name_def.second); |
365 | } |
366 | stats_publisher_->PublishGraphProto(graph_defs_for_publishing); |
367 | s = DoRegisterPartitions(popts, std::move(graph_defs)); |
368 | } |
369 | mu_.lock(); |
370 | init_result_ = s; |
371 | init_done_.Notify(); |
372 | } else { |
373 | mu_.unlock(); |
374 | init_done_.WaitForNotification(); |
375 | mu_.lock(); |
376 | } |
377 | const Status result = init_result_; |
378 | mu_.unlock(); |
379 | return result; |
380 | } |
381 | } |
382 | |
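// Returns the task portion of a node's assigned device name, which identifies
// the worker that owns the node when the graph is partitioned.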
383 | static string SplitByWorker(const Node* node) { |
384 | string task; |
385 | string device; |
386 | CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task, |
387 | &device)) |
388 | << "node: " << node->name() << " dev: " << node->assigned_device_name(); |
389 | return task; |
390 | } |
391 | |
392 | void MasterSession::ReffedClientGraph::TrackFeedsAndFetches( |
393 | Part* part, const GraphDef& graph_def, const PartitionOptions& popts) { |
394 | for (int i = 0; i < graph_def.node_size(); ++i) { |
395 | const NodeDef& ndef = graph_def.node(i); |
    const bool is_recv = ndef.op() == "_Recv";
    const bool is_send = ndef.op() == "_Send";
398 | |
399 | if (is_recv || is_send) { |
400 | // Only send/recv nodes that were added as feeds and fetches |
401 | // (client-terminated) should be tracked. Other send/recv nodes |
402 | // are for transferring data between partitions / memory spaces. |
403 | bool client_terminated; |
      TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
      if (client_terminated) {
        string name;
        TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
        string send_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
        string recv_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
        uint64 send_device_incarnation;
        TF_CHECK_OK(
            GetNodeAttr(ndef, "send_device_incarnation",
                        reinterpret_cast<int64_t*>(&send_device_incarnation)));
416 | const string& key = |
417 | Rendezvous::CreateKey(send_device, send_device_incarnation, |
418 | recv_device, name, FrameAndIter(0, 0)); |
419 | |
420 | if (is_recv) { |
421 | part->feed_key.insert({name, key}); |
422 | } else { |
423 | part->key_fetch.insert({key, name}); |
424 | } |
425 | } |
426 | } |
427 | } |
428 | } |
429 | |
430 | Status MasterSession::ReffedClientGraph::DoBuildPartitions( |
431 | PartitionOptions popts, ClientGraph* client_graph, |
432 | std::unordered_map<string, GraphDef>* out_partitions) { |
433 | if (popts.need_to_record_start_times) { |
434 | CostModel cost_model(true); |
435 | cost_model.InitFromGraph(client_graph->graph); |
436 | // TODO(yuanbyu): Use the real cost model. |
437 | // execution_state_->MergeFromGlobal(&cost_model); |
438 | SlackAnalysis sa(&client_graph->graph, &cost_model); |
439 | sa.ComputeAsap(&popts.start_times); |
440 | } |
441 | |
442 | // Partition the graph. |
443 | return Partition(popts, &client_graph->graph, out_partitions); |
444 | } |
445 | |
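// Looks up a worker for each partition and registers the partition's GraphDef
// on that worker via parallel RegisterGraph calls, recording the graph handle
// returned by each worker.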
446 | Status MasterSession::ReffedClientGraph::DoRegisterPartitions( |
447 | const PartitionOptions& popts, |
448 | std::unordered_map<string, GraphDef> graph_partitions) { |
449 | partitions_.reserve(graph_partitions.size()); |
450 | Status s; |
451 | for (auto& name_def : graph_partitions) { |
452 | partitions_.emplace_back(); |
453 | Part* part = &partitions_.back(); |
454 | part->name = name_def.first; |
455 | TrackFeedsAndFetches(part, name_def.second, popts); |
456 | part->worker = worker_cache_->GetOrCreateWorker(part->name); |
457 | if (part->worker == nullptr) { |
      s = errors::NotFound("worker ", part->name);
459 | break; |
460 | } |
461 | } |
462 | if (!s.ok()) { |
463 | for (Part& part : partitions_) { |
464 | worker_cache_->ReleaseWorker(part.name, part.worker); |
465 | part.worker = nullptr; |
466 | } |
467 | return s; |
468 | } |
469 | struct Call { |
470 | RegisterGraphRequest req; |
471 | RegisterGraphResponse resp; |
472 | Status status; |
473 | }; |
474 | const int num = partitions_.size(); |
475 | gtl::InlinedVector<Call, 4> calls(num); |
476 | BlockingCounter done(num); |
477 | for (int i = 0; i < num; ++i) { |
478 | const Part& part = partitions_[i]; |
479 | Call* c = &calls[i]; |
480 | c->req.set_session_handle(session_handle_); |
481 | c->req.set_create_worker_session_called(!should_deregister_); |
482 | c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); |
483 | StripDefaultAttributes(*OpRegistry::Global(), |
484 | c->req.mutable_graph_def()->mutable_node()); |
485 | *c->req.mutable_config_proto() = session_opts_.config; |
486 | *c->req.mutable_graph_options() = session_opts_.config.graph_options(); |
487 | *c->req.mutable_debug_options() = |
488 | callable_opts_.run_options().debug_options(); |
489 | c->req.set_collective_graph_key(collective_graph_key_); |
490 | VLOG(2) << "Register " << c->req.graph_def().DebugString(); |
491 | auto cb = [c, &done](const Status& s) { |
492 | c->status = s; |
493 | done.DecrementCount(); |
494 | }; |
495 | part.worker->RegisterGraphAsync(&c->req, &c->resp, cb); |
496 | } |
497 | done.Wait(); |
498 | for (int i = 0; i < num; ++i) { |
499 | Call* c = &calls[i]; |
500 | s.Update(c->status); |
501 | partitions_[i].graph_handle = c->resp.graph_handle(); |
502 | } |
503 | return s; |
504 | } |
505 | |
506 | namespace { |
507 | // Helper class to manage "num" parallel RunGraph calls. |
508 | class RunManyGraphs { |
509 | public: |
510 | explicit RunManyGraphs(int num) : calls_(num), pending_(num) {} |
511 | |
512 | ~RunManyGraphs() {} |
513 | |
  // State for one RunGraph call to a single worker partition.
515 | struct Call { |
516 | CallOptions opts; |
517 | const string* worker_name; |
518 | std::atomic<bool> done{false}; |
519 | std::unique_ptr<MutableRunGraphRequestWrapper> req; |
520 | std::unique_ptr<MutableRunGraphResponseWrapper> resp; |
521 | }; |
522 | Call* get(int index) { return &calls_[index]; } |
523 | |
524 | // When the index-th call is done, updates the overall status. |
525 | void WhenDone(int index, const Status& s) { |
    TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
527 | Call* call = get(index); |
528 | call->done = true; |
529 | auto resp = call->resp.get(); |
530 | if (resp->status_code() != error::Code::OK) { |
531 | // resp->status_code will only be non-OK if s.ok(). |
532 | mutex_lock l(mu_); |
533 | Status resp_status = call->resp->status(); |
534 | ReportBadStatus(errors::CreateWithUpdatedMessage( |
          resp_status, strings::StrCat("From ", *call->worker_name, ":\n",
536 | resp_status.error_message()))); |
537 | } else if (!s.ok()) { |
538 | mutex_lock l(mu_); |
539 | ReportBadStatus(errors::CreateWithUpdatedMessage( |
          s, strings::StrCat("From ", *call->worker_name, ":\n",
541 | s.error_message()))); |
542 | } |
543 | pending_.DecrementCount(); |
544 | } |
545 | |
546 | void StartCancel() { |
547 | mutex_lock l(mu_); |
    ReportBadStatus(errors::Cancelled("RunManyGraphs"));
549 | } |
550 | |
551 | void Wait() { |
    // Check the error status every 60 seconds in order to print a log message
    // in the event of a hang.
554 | const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60); |
555 | while (true) { |
556 | if (pending_.WaitFor(kCheckErrorPeriod)) { |
557 | return; |
558 | } |
559 | if (!status().ok()) { |
560 | break; |
561 | } |
562 | } |
563 | |
564 | // The step has failed. Wait for another 60 seconds before diagnosing a |
565 | // hang. |
566 | DCHECK(!status().ok()); |
567 | if (pending_.WaitFor(kCheckErrorPeriod)) { |
568 | return; |
569 | } |
570 | LOG(ERROR) |
571 | << "RunStep still blocked after 60 seconds. Failed with error status: " |
572 | << status(); |
573 | for (const Call& call : calls_) { |
574 | if (!call.done) { |
575 | LOG(ERROR) << "- No response from RunGraph call to worker: " |
576 | << *call.worker_name; |
577 | } |
578 | } |
579 | pending_.Wait(); |
580 | } |
581 | |
582 | Status status() const { |
583 | mutex_lock l(mu_); |
    // Concatenate the status objects in this StatusGroup to get the
    // aggregated status, as each status in status_group_ is already a
    // summarized status.
586 | return status_group_.as_concatenated_status(); |
587 | } |
588 | |
589 | private: |
590 | gtl::InlinedVector<Call, 4> calls_; |
591 | |
592 | BlockingCounter pending_; |
593 | mutable mutex mu_; |
594 | StatusGroup status_group_ TF_GUARDED_BY(mu_); |
595 | bool cancel_issued_ TF_GUARDED_BY(mu_) = false; |
596 | |
597 | void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
598 | VLOG(1) << "Master received error status " << s; |
599 | if (!cancel_issued_ && !StatusGroup::IsDerived(s)) { |
600 | // Only start cancelling other workers upon receiving a non-derived |
601 | // error |
602 | cancel_issued_ = true; |
603 | |
      VLOG(1) << "Master received error report. Cancelling remaining workers.";
605 | for (Call& call : calls_) { |
606 | call.opts.StartCancel(); |
607 | } |
608 | } |
609 | |
610 | status_group_.Update(s); |
611 | } |
612 | |
613 | TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs); |
614 | }; |
615 | |
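// Overloads that adapt the two client request types (RunStep and RunCallable)
// to the common "add a send" operation on a worker RunGraph request, so that
// RunPartitionsHelper can be written generically over the request type.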
616 | Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req, |
617 | MutableRunGraphRequestWrapper* worker_req, |
618 | size_t index, const string& send_key) { |
619 | return worker_req->AddSendFromRunStepRequest(client_req, index, send_key); |
620 | } |
621 | |
622 | Status AddSendFromClientRequest(const RunCallableRequest& client_req, |
623 | MutableRunGraphRequestWrapper* worker_req, |
624 | size_t index, const string& send_key) { |
625 | return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key); |
626 | } |
627 | |
628 | // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for |
629 | // in-process messages. |
630 | struct RunCallableResponseWrapper { |
631 | RunCallableResponse* resp; // Not owned. |
632 | std::unordered_map<string, TensorProto> fetch_key_to_protos; |
633 | |
634 | RunMetadata* mutable_metadata() { return resp->mutable_metadata(); } |
635 | |
636 | Status AddTensorFromRunGraphResponse( |
637 | const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp, |
638 | size_t index) { |
639 | return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]); |
640 | } |
641 | }; |
642 | } // namespace |
643 | |
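// Builds one RunGraph request per partition (routing feeds and fetch keys to
// the partitions that own them), issues all requests in parallel with
// cancellation support, and collects the fetched tensors, step stats, cost
// graphs, and partition graphs into the response and per-step state.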
644 | template <class FetchListType, class ClientRequestType, |
645 | class ClientResponseType> |
646 | Status MasterSession::ReffedClientGraph::RunPartitionsHelper( |
647 | const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds, |
648 | const FetchListType& fetches, const MasterEnv* env, int64_t step_id, |
649 | int64_t execution_count, PerStepState* pss, CallOptions* call_opts, |
650 | const ClientRequestType& req, ClientResponseType* resp, |
651 | CancellationManager* cm, bool is_last_partial_run) { |
652 | // Collect execution cost stats on a smoothly decreasing frequency. |
653 | ExecutorOpts exec_opts; |
654 | if (pss->report_tensor_allocations_upon_oom) { |
655 | exec_opts.set_report_tensor_allocations_upon_oom(true); |
656 | } |
657 | if (pss->collect_costs) { |
658 | exec_opts.set_record_costs(true); |
659 | } |
660 | if (pss->collect_timeline) { |
661 | exec_opts.set_record_timeline(true); |
662 | } |
663 | if (pss->collect_rpcs) { |
664 | SetRPCLogging(true); |
665 | } |
666 | if (pss->collect_partition_graphs) { |
667 | exec_opts.set_record_partition_graphs(true); |
668 | } |
669 | if (pss->collect_costs || pss->collect_timeline) { |
670 | pss->step_stats.resize(partitions_.size()); |
671 | } |
672 | |
673 | const int num = partitions_.size(); |
674 | RunManyGraphs calls(num); |
675 | |
676 | for (int i = 0; i < num; ++i) { |
677 | const Part& part = partitions_[i]; |
678 | RunManyGraphs::Call* c = calls.get(i); |
679 | c->worker_name = &part.name; |
680 | c->req.reset(part.worker->CreateRunGraphRequest()); |
681 | c->resp.reset(part.worker->CreateRunGraphResponse()); |
682 | if (is_partial_) { |
683 | c->req->set_is_partial(is_partial_); |
684 | c->req->set_is_last_partial_run(is_last_partial_run); |
685 | } |
686 | c->req->set_session_handle(session_handle_); |
687 | c->req->set_create_worker_session_called(!should_deregister_); |
688 | c->req->set_graph_handle(part.graph_handle); |
689 | c->req->set_step_id(step_id); |
690 | *c->req->mutable_exec_opts() = exec_opts; |
691 | c->req->set_store_errors_in_response_body(true); |
692 | c->req->set_request_id(GetUniqueRequestId()); |
693 | // If any feeds are provided, send the feed values together |
694 | // in the RunGraph request. |
695 | // In the partial case, we only want to include feeds provided in the req. |
696 | // In the non-partial case, all feeds in the request are in the part. |
697 | // We keep these as separate paths for now, to ensure we aren't |
698 | // inadvertently slowing down the normal run path. |
699 | if (is_partial_) { |
700 | for (const auto& name_index : feeds) { |
701 | const auto iter = part.feed_key.find(string(name_index.first)); |
702 | if (iter == part.feed_key.end()) { |
703 | // The provided feed must be for a different partition. |
704 | continue; |
705 | } |
706 | const string& key = iter->second; |
707 | TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(), |
708 | name_index.second, key)); |
709 | } |
      // TODO(suharshs): Make a map from fetch name to fetch_key to make this
      // faster. For now, we just iterate through partitions to find the
      // matching key.
712 | for (const string& req_fetch : fetches) { |
713 | for (const auto& key_fetch : part.key_fetch) { |
714 | if (key_fetch.second == req_fetch) { |
715 | c->req->add_recv_key(key_fetch.first); |
716 | break; |
717 | } |
718 | } |
719 | } |
720 | } else { |
721 | for (const auto& feed_key : part.feed_key) { |
722 | const string& feed = feed_key.first; |
723 | const string& key = feed_key.second; |
724 | auto iter = feeds.find(feed); |
725 | if (iter == feeds.end()) { |
          return errors::Internal("No feed index found for feed: ", feed);
727 | } |
728 | const int64_t feed_index = iter->second; |
729 | TF_RETURN_IF_ERROR( |
730 | AddSendFromClientRequest(req, c->req.get(), feed_index, key)); |
731 | } |
732 | for (const auto& key_fetch : part.key_fetch) { |
733 | const string& key = key_fetch.first; |
734 | c->req->add_recv_key(key); |
735 | } |
736 | } |
737 | } |
738 | |
739 | // Issues RunGraph calls. |
740 | for (int i = 0; i < num; ++i) { |
741 | const Part& part = partitions_[i]; |
742 | RunManyGraphs::Call* call = calls.get(i); |
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
744 | part.worker->RunGraphAsync( |
745 | &call->opts, call->req.get(), call->resp.get(), |
746 | std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1)); |
747 | } |
748 | |
749 | // Waits for the RunGraph calls. |
750 | call_opts->SetCancelCallback([&calls]() { |
    LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
                 "worker operations.";
753 | calls.StartCancel(); |
754 | }); |
755 | auto token = cm->get_cancellation_token(); |
756 | const bool success = |
757 | cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); }); |
758 | if (!success) { |
759 | calls.StartCancel(); |
760 | } |
761 | calls.Wait(); |
762 | call_opts->ClearCancelCallback(); |
763 | if (success) { |
764 | cm->DeregisterCallback(token); |
765 | } else { |
    return errors::Cancelled("Step was cancelled");
767 | } |
768 | TF_RETURN_IF_ERROR(calls.status()); |
769 | |
770 | // Collects fetches and metadata. |
771 | Status status; |
772 | for (int i = 0; i < num; ++i) { |
773 | const Part& part = partitions_[i]; |
774 | MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get(); |
775 | for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) { |
776 | auto iter = part.key_fetch.find(run_graph_resp->recv_key(j)); |
777 | if (iter == part.key_fetch.end()) { |
        status.Update(errors::Internal("Unexpected fetch key: ",
779 | run_graph_resp->recv_key(j))); |
780 | break; |
781 | } |
782 | const string& fetch = iter->second; |
783 | status.Update( |
784 | resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j)); |
785 | if (!status.ok()) { |
786 | break; |
787 | } |
788 | } |
789 | if (pss->collect_timeline) { |
790 | pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats()); |
791 | } |
792 | if (pss->collect_costs) { |
793 | CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph(); |
794 | for (int j = 0; j < cost_graph->node_size(); ++j) { |
795 | resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap( |
796 | cost_graph->mutable_node(j)); |
797 | } |
798 | } |
799 | if (pss->collect_partition_graphs) { |
800 | protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = |
801 | resp->mutable_metadata()->mutable_partition_graphs(); |
802 | for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) { |
803 | partition_graph_defs->Add()->Swap( |
804 | run_graph_resp->mutable_partition_graph(i)); |
805 | } |
806 | } |
807 | } |
808 | return status; |
809 | } |
810 | |
811 | Status MasterSession::ReffedClientGraph::RunPartitions( |
812 | const MasterEnv* env, int64_t step_id, int64_t execution_count, |
813 | PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req, |
814 | MutableRunStepResponseWrapper* resp, CancellationManager* cm, |
815 | const bool is_last_partial_run) { |
816 | VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " |
817 | << execution_count; |
818 | // Maps the names of fed tensors to their index in `req`. |
819 | std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3); |
820 | for (size_t i = 0; i < req.num_feeds(); ++i) { |
821 | if (!feeds.insert({req.feed_name(i), i}).second) { |
      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
823 | } |
824 | } |
825 | |
826 | std::vector<string> fetches; |
827 | fetches.reserve(req.num_fetches()); |
828 | for (size_t i = 0; i < req.num_fetches(); ++i) { |
829 | fetches.push_back(req.fetch_name(i)); |
830 | } |
831 | |
832 | return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss, |
833 | call_opts, req, resp, cm, is_last_partial_run); |
834 | } |
835 | |
836 | Status MasterSession::ReffedClientGraph::RunPartitions( |
837 | const MasterEnv* env, int64_t step_id, int64_t execution_count, |
838 | PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req, |
839 | RunCallableResponse* resp, CancellationManager* cm) { |
840 | VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " |
841 | << execution_count; |
842 | // Maps the names of fed tensors to their index in `req`. |
843 | std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3); |
844 | for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) { |
845 | if (!feeds.insert({callable_opts_.feed(i), i}).second) { |
846 | // MakeCallable will fail if there are two feeds with the same name. |
      return errors::Internal("Duplicated feeds in callable: ",
848 | callable_opts_.feed(i)); |
849 | } |
850 | } |
851 | |
852 | // Create a wrapped response object to collect the fetched values and |
853 | // rearrange them for the RunCallableResponse. |
854 | RunCallableResponseWrapper wrapped_resp; |
855 | wrapped_resp.resp = resp; |
856 | |
857 | TF_RETURN_IF_ERROR(RunPartitionsHelper( |
858 | feeds, callable_opts_.fetch(), env, step_id, execution_count, pss, |
859 | call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */)); |
860 | |
861 | // Collects fetches. |
862 | for (const string& fetch : callable_opts_.fetch()) { |
863 | TensorProto* fetch_proto = resp->mutable_fetch()->Add(); |
864 | auto iter = wrapped_resp.fetch_key_to_protos.find(fetch); |
865 | if (iter == wrapped_resp.fetch_key_to_protos.end()) { |
      return errors::Internal("Worker did not return a value for fetch: ",
867 | fetch); |
868 | } |
869 | fetch_proto->Swap(&iter->second); |
870 | } |
871 | return OkStatus(); |
872 | } |
873 | |
874 | namespace { |
875 | |
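// Broadcasts a CleanupGraph request for one step to a set of workers and
// invokes a single callback with the aggregated status once all responses
// have arrived. Deletes itself after running the callback.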
876 | class CleanupBroadcastHelper { |
877 | public: |
878 | CleanupBroadcastHelper(int64_t step_id, int num_calls, StatusCallback done) |
879 | : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) { |
880 | req_.set_step_id(step_id); |
881 | } |
882 | |
883 | // Returns a non-owned pointer to a request buffer for all calls. |
884 | CleanupGraphRequest* request() { return &req_; } |
885 | |
886 | // Returns a non-owned pointer to a response buffer for the ith call. |
887 | CleanupGraphResponse* response(int i) { return &resps_[i]; } |
888 | |
889 | // Called when the ith response is received. |
890 | void call_done(int i, const Status& s) { |
891 | bool run_callback = false; |
892 | Status status_copy; |
893 | { |
894 | mutex_lock l(mu_); |
895 | status_.Update(s); |
896 | if (--num_pending_ == 0) { |
897 | run_callback = true; |
898 | status_copy = status_; |
899 | } |
900 | } |
901 | if (run_callback) { |
902 | done_(status_copy); |
903 | // This is the last call, so delete the helper object. |
904 | delete this; |
905 | } |
906 | } |
907 | |
908 | private: |
909 | // A single request shared between all workers. |
910 | CleanupGraphRequest req_; |
911 | // One response buffer for each worker. |
912 | gtl::InlinedVector<CleanupGraphResponse, 4> resps_; |
913 | |
914 | mutex mu_; |
915 | // Number of requests remaining to be collected. |
916 | int num_pending_ TF_GUARDED_BY(mu_); |
917 | // Aggregate status of the operation. |
918 | Status status_ TF_GUARDED_BY(mu_); |
919 | // Callback to be called when all operations complete. |
920 | StatusCallback done_; |
921 | |
922 | TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper); |
923 | }; |
924 | |
925 | } // namespace |
926 | |
927 | void MasterSession::ReffedClientGraph::CleanupPartitionsAsync( |
928 | int64_t step_id, StatusCallback done) { |
929 | const int num = partitions_.size(); |
930 | // Helper object will be deleted when the final call completes. |
931 | CleanupBroadcastHelper* helper = |
932 | new CleanupBroadcastHelper(step_id, num, std::move(done)); |
933 | for (int i = 0; i < num; ++i) { |
934 | const Part& part = partitions_[i]; |
935 | part.worker->CleanupGraphAsync( |
936 | helper->request(), helper->response(i), |
937 | [helper, i](const Status& s) { helper->call_done(i, s); }); |
938 | } |
939 | } |
940 | |
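// Post-processes the per-partition step stats and RPC logs gathered for a
// step: feeds them to the profile handler and, when timeline collection is
// enabled, merges them into the RunMetadata response or publishes them via
// the stats publisher.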
941 | void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id, |
942 | PerStepState* pss, |
943 | ProfileHandler* ph, |
944 | const RunOptions& options, |
945 | RunMetadata* resp) { |
946 | if (!pss->collect_costs && !pss->collect_timeline) return; |
947 | |
948 | // Out-of-band logging data is collected now, during post-processing. |
949 | if (pss->collect_timeline) { |
950 | SetRPCLogging(false); |
951 | RetrieveLogs(step_id, &pss->rpc_stats); |
952 | } |
953 | for (size_t i = 0; i < partitions_.size(); ++i) { |
954 | const StepStats& ss = pss->step_stats[i]; |
955 | if (ph) { |
956 | for (const auto& ds : ss.dev_stats()) { |
957 | ProcessDeviceStats(ph, ds, false /*is_rpc*/); |
958 | } |
959 | } |
960 | } |
961 | if (ph) { |
962 | for (const auto& ds : pss->rpc_stats.dev_stats()) { |
963 | ProcessDeviceStats(ph, ds, true /*is_rpc*/); |
964 | } |
965 | ph->StepDone(pss->start_micros, pss->end_micros, |
966 | Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/, |
967 | OkStatus()); |
968 | } |
969 | // Assemble all stats for this timeline into a merged StepStats. |
970 | if (pss->collect_timeline) { |
971 | StepStats step_stats_proto; |
972 | step_stats_proto.Swap(&pss->rpc_stats); |
973 | for (size_t i = 0; i < partitions_.size(); ++i) { |
974 | step_stats_proto.MergeFrom(pss->step_stats[i]); |
975 | pss->step_stats[i].Clear(); |
976 | } |
977 | pss->step_stats.clear(); |
978 | // Copy the stats back, but only for on-demand profiling to avoid slowing |
979 | // down calls that trigger the automatic profiling. |
980 | if (options.trace_level() == RunOptions::FULL_TRACE) { |
981 | resp->mutable_step_stats()->Swap(&step_stats_proto); |
982 | } else { |
      // When FULL_TRACE is requested the stats are already returned via the
      // Session API above, so publishing here would be duplicated.
985 | stats_publisher_->PublishStatsProto(step_stats_proto); |
986 | } |
987 | } |
988 | } |
989 | |
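// Records one device's NodeExecStats with the profile handler, resolving op
// types and detail text from the name-to-node-details map built at
// construction time.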
990 | void MasterSession::ReffedClientGraph::ProcessDeviceStats( |
991 | ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) { |
992 | const string& dev_name = ds.device(); |
993 | VLOG(1) << "Device " << dev_name << " reports stats for " |
          << ds.node_stats_size() << " nodes";
995 | for (const auto& ns : ds.node_stats()) { |
996 | if (is_rpc) { |
997 | // We don't have access to a good Node pointer, so we rely on |
998 | // sufficient data being present in the NodeExecStats. |
      ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
1000 | ns.timeline_label()); |
1001 | } else { |
1002 | auto iter = name_to_node_details_.find(ns.node_name()); |
1003 | const bool found_node_in_graph = iter != name_to_node_details_.end(); |
1004 | if (!found_node_in_graph && ns.timeline_label().empty()) { |
1005 | // The counter incrementing is not thread-safe. But we don't really |
1006 | // care. |
1007 | // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for |
1008 | // more general usage. |
1009 | static int log_counter = 0; |
1010 | if (log_counter < 10) { |
1011 | log_counter++; |
1012 | LOG(WARNING) << "Failed to find node " << ns.node_name() |
1013 | << " for dev " << dev_name; |
1014 | } |
1015 | continue; |
1016 | } |
1017 | const string& optype = |
1018 | found_node_in_graph ? iter->second.type_string : ns.node_name(); |
1019 | string details; |
1020 | if (!ns.timeline_label().empty()) { |
1021 | details = ns.timeline_label(); |
1022 | } else if (found_node_in_graph) { |
1023 | details = DetailText(iter->second, ns); |
1024 | } else { |
1025 | // Leave details string empty |
1026 | } |
1027 | ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype, |
1028 | details); |
1029 | } |
1030 | } |
1031 | } |
1032 | |
1033 | // TODO(suharshs): Merge with CheckFetches in DirectSession. |
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
1037 | Status MasterSession::ReffedClientGraph::CheckFetches( |
1038 | const RunStepRequestWrapper& req, const RunState* run_state, |
1039 | GraphExecutionState* execution_state) { |
1040 | // Build the set of pending feeds that we haven't seen. |
1041 | std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; |
1042 | for (const auto& input : run_state->pending_inputs) { |
1043 | // Skip if already fed. |
1044 | if (input.second) continue; |
1045 | TensorId id(ParseTensorName(input.first)); |
1046 | const Node* n = execution_state->get_node_by_name(string(id.first)); |
1047 | if (n == nullptr) { |
      return errors::NotFound("Feed ", input.first, ": not found");
1049 | } |
1050 | pending_feeds.insert(id); |
1051 | } |
1052 | for (size_t i = 0; i < req.num_feeds(); ++i) { |
1053 | const TensorId id(ParseTensorName(req.feed_name(i))); |
1054 | pending_feeds.erase(id); |
1055 | } |
1056 | |
1057 | // Initialize the stack with the fetch nodes. |
1058 | std::vector<const Node*> stack; |
1059 | for (size_t i = 0; i < req.num_fetches(); ++i) { |
1060 | const string& fetch = req.fetch_name(i); |
1061 | const TensorId id(ParseTensorName(fetch)); |
1062 | const Node* n = execution_state->get_node_by_name(string(id.first)); |
1063 | if (n == nullptr) { |
      return errors::NotFound("Fetch ", fetch, ": not found");
1065 | } |
1066 | stack.push_back(n); |
1067 | } |
1068 | |
1069 | // Any tensor needed for fetches can't be in pending_feeds. |
1070 | // We need to use the original full graph from execution state. |
1071 | const Graph* graph = execution_state->full_graph(); |
1072 | std::vector<bool> visited(graph->num_node_ids(), false); |
1073 | while (!stack.empty()) { |
1074 | const Node* n = stack.back(); |
1075 | stack.pop_back(); |
1076 | |
1077 | for (const Edge* in_edge : n->in_edges()) { |
1078 | const Node* in_node = in_edge->src(); |
1079 | if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { |
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
1084 | } |
1085 | if (!visited[in_node->id()]) { |
1086 | visited[in_node->id()] = true; |
1087 | stack.push_back(in_node); |
1088 | } |
1089 | } |
1090 | } |
1091 | return OkStatus(); |
1092 | } |
1093 | |
1094 | // Asynchronously deregisters subgraphs on the workers, without waiting for the |
1095 | // result. |
1096 | void MasterSession::ReffedClientGraph::DeregisterPartitions() { |
1097 | struct Call { |
1098 | DeregisterGraphRequest req; |
1099 | DeregisterGraphResponse resp; |
1100 | }; |
1101 | for (Part& part : partitions_) { |
1102 | // The graph handle may be empty if we failed during partition registration. |
1103 | if (!part.graph_handle.empty()) { |
1104 | Call* c = new Call; |
1105 | c->req.set_session_handle(session_handle_); |
1106 | c->req.set_create_worker_session_called(!should_deregister_); |
1107 | c->req.set_graph_handle(part.graph_handle); |
1108 | // NOTE(mrry): We must capture `worker_cache_` since `this` |
1109 | // could be deleted before the callback is called. |
1110 | WorkerCacheInterface* worker_cache = worker_cache_; |
1111 | const string name = part.name; |
1112 | WorkerInterface* w = part.worker; |
1113 | CHECK_NOTNULL(w); |
1114 | auto cb = [worker_cache, c, name, w](const Status& s) { |
1115 | if (!s.ok()) { |
1116 | // This error is potentially benign, so we don't log at the |
1117 | // error level. |
1118 | LOG(INFO) << "DeregisterGraph error: " << s; |
1119 | } |
1120 | delete c; |
1121 | worker_cache->ReleaseWorker(name, w); |
1122 | }; |
1123 | w->DeregisterGraphAsync(&c->req, &c->resp, cb); |
1124 | } |
1125 | } |
1126 | } |
1127 | |
1128 | namespace { |
1129 | void CopyAndSortStrings(size_t size, |
1130 | const std::function<string(size_t)>& input_accessor, |
1131 | protobuf::RepeatedPtrField<string>* output) { |
1134 | for (size_t i = 0; i < size; ++i) { |
1135 | output->Add(input_accessor(i)); |
1136 | } |
1137 | std::sort(output->begin(), output->end()); |
1138 | } |
1139 | } // namespace |
1140 | |
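// Converts a RunStep request and session config into the BuildGraphOptions
// used to build (or look up) a client graph: sorted feed/fetch/target lists,
// debug options, and collective execution settings.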
1141 | void BuildBuildGraphOptions(const RunStepRequestWrapper& req, |
1142 | const ConfigProto& config, |
1143 | BuildGraphOptions* opts) { |
1144 | CallableOptions* callable_opts = &opts->callable_options; |
1145 | CopyAndSortStrings( |
1146 | req.num_feeds(), [&req](size_t i) { return req.feed_name(i); }, |
1147 | callable_opts->mutable_feed()); |
1148 | CopyAndSortStrings( |
1149 | req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); }, |
1150 | callable_opts->mutable_fetch()); |
1151 | CopyAndSortStrings( |
1152 | req.num_targets(), [&req](size_t i) { return req.target_name(i); }, |
1153 | callable_opts->mutable_target()); |
1154 | |
1155 | if (!req.options().debug_options().debug_tensor_watch_opts().empty()) { |
1156 | *callable_opts->mutable_run_options()->mutable_debug_options() = |
1157 | req.options().debug_options(); |
1158 | } |
1159 | |
1160 | opts->collective_graph_key = |
1161 | req.options().experimental().collective_graph_key(); |
1162 | if (config.experimental().collective_deterministic_sequential_execution()) { |
1163 | opts->collective_order = GraphCollectiveOrder::kEdges; |
1164 | } else if (config.experimental().collective_nccl()) { |
1165 | opts->collective_order = GraphCollectiveOrder::kAttrs; |
1166 | } |
1167 | } |
1168 | |
1169 | void BuildBuildGraphOptions(const PartialRunSetupRequest& req, |
1170 | BuildGraphOptions* opts) { |
1171 | CallableOptions* callable_opts = &opts->callable_options; |
1172 | CopyAndSortStrings( |
1173 | req.feed_size(), [&req](size_t i) { return req.feed(i); }, |
1174 | callable_opts->mutable_feed()); |
1175 | CopyAndSortStrings( |
1176 | req.fetch_size(), [&req](size_t i) { return req.fetch(i); }, |
1177 | callable_opts->mutable_fetch()); |
1178 | CopyAndSortStrings( |
1179 | req.target_size(), [&req](size_t i) { return req.target(i); }, |
1180 | callable_opts->mutable_target()); |
1181 | |
1182 | // TODO(cais): Add TFDBG support to partial runs. |
1183 | } |
1184 | |
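// Computes a stable hash over the feed, target, and fetch names (and any
// debug tensor watches) so that equivalent BuildGraphOptions can share a
// cached compiled graph.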
1185 | uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { |
1186 | uint64 h = 0x2b992ddfa23249d6ull; |
1187 | for (const string& name : opts.callable_options.feed()) { |
1188 | h = Hash64(name.c_str(), name.size(), h); |
1189 | } |
1190 | for (const string& name : opts.callable_options.target()) { |
1191 | h = Hash64(name.c_str(), name.size(), h); |
1192 | } |
1193 | for (const string& name : opts.callable_options.fetch()) { |
1194 | h = Hash64(name.c_str(), name.size(), h); |
1195 | } |
1196 | |
1197 | const DebugOptions& debug_options = |
1198 | opts.callable_options.run_options().debug_options(); |
1199 | if (!debug_options.debug_tensor_watch_opts().empty()) { |
1200 | const string watch_summary = |
1201 | SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts()); |
1202 | h = Hash64(watch_summary.c_str(), watch_summary.size(), h); |
1203 | } |
1204 | |
1205 | return h; |
1206 | } |
1207 | |
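// Renders a BuildGraphOptions as a human-readable string (feeds, targets,
// fetches, and the collective graph key) for logging and debugging.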
1208 | string BuildGraphOptionsString(const BuildGraphOptions& opts) { |
1209 | string buf; |
1210 | for (const string& name : opts.callable_options.feed()) { |
    strings::StrAppend(&buf, " FdE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.callable_options.target()) {
    strings::StrAppend(&buf, " TN: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.callable_options.fetch()) {
    strings::StrAppend(&buf, " FeE: ", name);
  }
  if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
    strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
  }
  strings::StrAppend(&buf, "\n");
1225 | return buf; |
1226 | } |
1227 | |
1228 | MasterSession::MasterSession( |
1229 | const SessionOptions& opt, const MasterEnv* env, |
1230 | std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, |
1231 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
1232 | std::unique_ptr<DeviceSet> device_set, |
1233 | std::vector<string> filtered_worker_list, |
1234 | StatsPublisherFactory stats_publisher_factory) |
1235 | : session_opts_(opt), |
1236 | env_(env), |
1237 | handle_(strings::FpToString(random::New64())), |
1238 | remote_devs_(std::move(remote_devs)), |
1239 | worker_cache_(std::move(worker_cache)), |
1240 | devices_(std::move(device_set)), |
1241 | filtered_worker_list_(std::move(filtered_worker_list)), |
1242 | stats_publisher_factory_(std::move(stats_publisher_factory)), |
1243 | graph_version_(0), |
1244 | run_graphs_(5), |
1245 | partial_run_graphs_(5) { |
1246 | UpdateLastAccessTime(); |
  CHECK(devices_) << "device_set was null!";
1248 | |
1249 | VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() |
1250 | << " #remote " << remote_devs_->size(); |
1251 | VLOG(1) << "Start master session " << handle_ |
1252 | << " with config: " << session_opts_.config.ShortDebugString(); |
1253 | } |
1254 | |
1255 | MasterSession::~MasterSession() { |
1256 | for (const auto& iter : run_graphs_) iter.second->Unref(); |
1257 | for (const auto& iter : partial_run_graphs_) iter.second->Unref(); |
1258 | } |
1259 | |
1260 | void MasterSession::UpdateLastAccessTime() { |
1261 | last_access_time_usec_.store(Env::Default()->NowMicros()); |
1262 | } |
1263 | |
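// Builds the base GraphExecutionState for this session from the client's
// GraphDef and then creates a per-session state on every remote worker.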
1264 | Status MasterSession::Create(GraphDef&& graph_def, |
1265 | const ClusterDef& cluster_def) { |
1266 | if (session_opts_.config.use_per_session_threads() || |
1267 | session_opts_.config.session_inter_op_thread_pool_size() > 0) { |
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
1270 | } |
1271 | if (session_opts_.config.graph_options().place_pruned_graph()) { |
1272 | // TODO(b/29900832): Fix this or remove the option. |
    LOG(WARNING) << "Distributed session does not support the "
                    "place_pruned_graph option.";
1275 | session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); |
1276 | } |
1277 | |
1278 | GraphExecutionStateOptions execution_options; |
1279 | execution_options.device_set = devices_.get(); |
1280 | execution_options.session_options = &session_opts_; |
1281 | { |
1282 | mutex_lock l(mu_); |
1283 | TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph( |
1284 | std::move(graph_def), execution_options, &execution_state_)); |
1285 | } |
1286 | should_delete_worker_sessions_ = true; |
1287 | return CreateWorkerSessions(cluster_def); |
1288 | } |
1289 | |
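// Broadcasts CreateWorkerSession requests to every worker in the filtered
// worker list, propagating the session handle, master task and incarnation,
// cluster device attributes (when device sharing is enabled), the ServerDef,
// and the coordination service configuration. Blocks until all workers
// respond and returns the aggregated status.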
1290 | Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { |
1291 | const std::vector<string> worker_names = filtered_worker_list_; |
1292 | WorkerCacheInterface* worker_cache = get_worker_cache(); |
1293 | |
1294 | struct WorkerGroup { |
1295 | // The worker name. (Not owned.) |
1296 | const string* name; |
1297 | |
1298 | // The worker referenced by name. (Not owned.) |
1299 | WorkerInterface* worker = nullptr; |
1300 | |
1301 | // Request and responses used for a given worker. |
1302 | CreateWorkerSessionRequest request; |
1303 | CreateWorkerSessionResponse response; |
1304 | Status status = OkStatus(); |
1305 | }; |
1306 | BlockingCounter done(worker_names.size()); |
1307 | std::vector<WorkerGroup> workers(worker_names.size()); |
1308 | |
1309 | // Release the workers. |
1310 | auto cleanup = gtl::MakeCleanup([&workers, worker_cache] { |
1311 | for (auto&& worker_group : workers) { |
1312 | if (worker_group.worker != nullptr) { |
1313 | worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); |
1314 | } |
1315 | } |
1316 | }); |
1317 | |
1318 | string task_name; |
1319 | string local_device_name; |
1320 | DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(), |
1321 | &task_name, &local_device_name); |
1322 | const int64_t client_device_incarnation = |
1323 | devices_->client_device()->attributes().incarnation(); |
1324 | |
1325 | Status status = OkStatus(); |
1326 | // Create all the workers & kick off the computations. |
1327 | for (size_t i = 0; i < worker_names.size(); ++i) { |
1328 | workers[i].name = &worker_names[i]; |
1329 | workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]); |
1330 | workers[i].request.set_session_handle(handle_); |
1331 | workers[i].request.set_master_task(task_name); |
1332 | workers[i].request.set_master_incarnation(client_device_incarnation); |
1333 | if (session_opts_.config.share_cluster_devices_in_session() || |
1334 | session_opts_.config.experimental() |
1335 | .share_cluster_devices_in_session()) { |
1336 | for (const auto& remote_dev : devices_->devices()) { |
1337 | *workers[i].request.add_cluster_device_attributes() = |
1338 | remote_dev->attributes(); |
1339 | } |
1340 | |
1341 | if (!session_opts_.config.share_cluster_devices_in_session() && |
1342 | session_opts_.config.experimental() |
1343 | .share_cluster_devices_in_session()) { |
        LOG(WARNING)
            << "ConfigProto.Experimental.share_cluster_devices_in_session has "
               "been promoted to a non-experimental API. Please use "
               "ConfigProto.share_cluster_devices_in_session instead. The "
               "experimental option will be removed in the future.";
1349 | } |
1350 | } |
1351 | |
1352 | DeviceNameUtils::ParsedName name; |
1353 | if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) { |
      status = errors::Internal("Could not parse name ", worker_names[i]);
1355 | LOG(WARNING) << status; |
1356 | return status; |
1357 | } |
1358 | if (!name.has_job || !name.has_task) { |
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
1360 | LOG(WARNING) << status; |
1361 | return status; |
1362 | } |
1363 | |
    workers[i].request.mutable_server_def()->set_protocol("grpc");
1365 | workers[i].request.mutable_server_def()->set_job_name(name.job); |
1366 | workers[i].request.mutable_server_def()->set_task_index(name.task); |
1367 | if (!cluster_def.job().empty()) { |
1368 | *workers[i].request.mutable_server_def()->mutable_cluster() = cluster_def; |
1369 | // Session state is always isolated when ClusterSpec propagation |
1370 | // is in use. |
1371 | workers[i].request.set_isolate_session_state(true); |
1372 | } else { |
1373 | // NOTE(mrry): Do not set any component of the ServerDef, |
1374 | // because the worker will use its local configuration. |
1375 | workers[i].request.set_isolate_session_state( |
1376 | session_opts_.config.isolate_session_state()); |
1377 | } |
1378 | CoordinationServiceConfig coordination_config; |
    // For non-local targets, enable the coordination service by default when
    // it is not specified in the session options.
1381 | if (session_opts_.target != "local" && |
1382 | !session_opts_.config.experimental().has_coordination_config()) { |
      coordination_config.set_service_type("standalone");
1384 | } else { |
1385 | coordination_config = |
1386 | session_opts_.config.experimental().coordination_config(); |
1387 | } |
1388 | // Specify master task as coordination service leader. |
1389 | coordination_config.set_service_leader(task_name); |
1390 | *workers[i] |
1391 | .request.mutable_server_def() |
1392 | ->mutable_default_session_config() |
1393 | ->mutable_experimental() |
1394 | ->mutable_coordination_config() = coordination_config; |
1395 | |
1396 | if (session_opts_.config.experimental() |
1397 | .share_session_state_in_clusterspec_propagation()) { |
      // In a dynamic cluster, the ClusterSpec info is usually propagated by
      // master sessions. However, in data-parallel training with multiple
      // masters ("between-graph replication"), isolation must be disabled so
      // that different worker sessions can update the same variables on the
      // PS tasks.
1403 | workers[i].request.set_isolate_session_state(false); |
1404 | } |
1405 | } |
1406 | |
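  // Issue the CreateWorkerSession RPCs in parallel; each callback records the
  // worker's status and decrements `done`, which is waited on below.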
1407 | for (size_t i = 0; i < worker_names.size(); ++i) { |
1408 | auto cb = [i, &workers, &done](const Status& s) { |
1409 | workers[i].status = s; |
1410 | done.DecrementCount(); |
1411 | }; |
1412 | workers[i].worker->CreateWorkerSessionAsync(&workers[i].request, |
1413 | &workers[i].response, cb); |
1414 | } |
1415 | |
1416 | done.Wait(); |
1417 | for (size_t i = 0; i < workers.size(); ++i) { |
1418 | status.Update(workers[i].status); |
1419 | } |
1420 | return status; |
1421 | } |
1422 | |
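// Tears down the per-session state that CreateWorkerSessions() established on
// each worker in `filtered_worker_list_`.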
1423 | Status MasterSession::DeleteWorkerSessions() { |
1424 | WorkerCacheInterface* worker_cache = get_worker_cache(); |
1425 | const std::vector<string>& worker_names = filtered_worker_list_; |
1426 | |
1427 | struct WorkerGroup { |
1428 | // The worker name. (Not owned.) |
1429 | const string* name; |
1430 | |
1431 | // The worker referenced by name. (Not owned.) |
1432 | WorkerInterface* worker = nullptr; |
1433 | |
1434 | CallOptions call_opts; |
1435 | |
1436 | // Request and responses used for a given worker. |
1437 | DeleteWorkerSessionRequest request; |
1438 | DeleteWorkerSessionResponse response; |
1439 | Status status = OkStatus(); |
1440 | }; |
1441 | BlockingCounter done(worker_names.size()); |
1442 | std::vector<WorkerGroup> workers(worker_names.size()); |
1443 | |
1444 | // Release the workers. |
1445 | auto cleanup = gtl::MakeCleanup([&workers, worker_cache] { |
1446 | for (auto&& worker_group : workers) { |
1447 | if (worker_group.worker != nullptr) { |
1448 | worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); |
1449 | } |
1450 | } |
1451 | }); |
1452 | |
1453 | Status status = OkStatus(); |
  // Look up the workers and prepare their DeleteWorkerSession requests.
1455 | for (size_t i = 0; i < worker_names.size(); ++i) { |
1456 | workers[i].name = &worker_names[i]; |
1457 | workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]); |
1458 | workers[i].request.set_session_handle(handle_); |
1459 | // Since the worker may have gone away, set a timeout to avoid blocking the |
1460 | // session-close operation. |
1461 | workers[i].call_opts.SetTimeout(10000); |
1462 | } |
1463 | |
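  // Issue the DeleteWorkerSession RPCs in parallel; each callback records the
  // worker's status and decrements `done`.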
1464 | for (size_t i = 0; i < worker_names.size(); ++i) { |
1465 | auto cb = [i, &workers, &done](const Status& s) { |
1466 | workers[i].status = s; |
1467 | done.DecrementCount(); |
1468 | }; |
1469 | workers[i].worker->DeleteWorkerSessionAsync( |
1470 | &workers[i].call_opts, &workers[i].request, &workers[i].response, cb); |
1471 | } |
1472 | |
1473 | done.Wait(); |
1474 | for (size_t i = 0; i < workers.size(); ++i) { |
1475 | status.Update(workers[i].status); |
1476 | } |
1477 | return status; |
1478 | } |
1479 | |
1480 | Status MasterSession::ListDevices(ListDevicesResponse* resp) const { |
1481 | if (worker_cache_) { |
1482 | // This is a ClusterSpec-propagated session, and thus env_->local_devices |
1483 | // are invalid. |
1484 | |
1485 | // Mark the "client_device" as the sole local device. |
1486 | const Device* client_device = devices_->client_device(); |
1487 | for (const Device* dev : devices_->devices()) { |
1488 | if (dev != client_device) { |
1489 | *(resp->add_remote_device()) = dev->attributes(); |
1490 | } |
1491 | } |
1492 | *(resp->add_local_device()) = client_device->attributes(); |
1493 | } else { |
1494 | for (Device* dev : env_->local_devices) { |
1495 | *(resp->add_local_device()) = dev->attributes(); |
1496 | } |
1497 | for (auto&& dev : *remote_devs_) { |
1498 | *(resp->add_local_device()) = dev->attributes(); |
1499 | } |
1500 | } |
1501 | return OkStatus(); |
1502 | } |
1503 | |
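// Extends the session's graph with the nodes in `req->graph_def()` and bumps
// `graph_version_`; the previous execution state is released once the lock is
// dropped.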
1504 | Status MasterSession::Extend(const ExtendSessionRequest* req, |
1505 | ExtendSessionResponse* resp) { |
1506 | UpdateLastAccessTime(); |
1507 | std::unique_ptr<GraphExecutionState> extended_execution_state; |
1508 | { |
1509 | mutex_lock l(mu_); |
1510 | if (closed_) { |
      return errors::FailedPrecondition("Session is closed.");
1512 | } |
1513 | |
1514 | if (graph_version_ != req->current_graph_version()) { |
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
1518 | } |
1519 | |
1520 | CHECK(execution_state_); |
1521 | TF_RETURN_IF_ERROR( |
1522 | execution_state_->Extend(req->graph_def(), &extended_execution_state)); |
1523 | |
1524 | CHECK(extended_execution_state); |
1525 | // The old execution state will be released outside the lock. |
1526 | execution_state_.swap(extended_execution_state); |
1527 | ++graph_version_; |
1528 | resp->set_new_graph_version(graph_version_); |
1529 | } |
1530 | return OkStatus(); |
1531 | } |
1532 | |
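// Returns the ClusterSpec-propagated worker cache if this session owns one,
// and the master environment's shared worker cache otherwise.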
1533 | WorkerCacheInterface* MasterSession::get_worker_cache() const { |
1534 | if (worker_cache_) { |
1535 | return worker_cache_.get(); |
1536 | } |
1537 | return env_->worker_cache; |
1538 | } |
1539 | |
1540 | Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial, |
1541 | ReffedClientGraph** out_rcg, |
1542 | int64_t* out_count) { |
1543 | const uint64 hash = HashBuildGraphOptions(opts); |
1544 | { |
1545 | mutex_lock l(mu_); |
    // TODO(suharshs): We cache partial-run graphs and run graphs separately
    // because some preprocessing only needs to be run for partial-run calls.
1549 | RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_; |
1550 | auto iter = m->find(hash); |
1551 | if (iter == m->end()) { |
1552 | // We have not seen this subgraph before. Build the subgraph and |
1553 | // cache it. |
1554 | VLOG(1) << "Unseen hash " << hash << " for " |
1555 | << BuildGraphOptionsString(opts) << " is_partial = " << is_partial |
              << "\n";
1557 | std::unique_ptr<ClientGraph> client_graph; |
1558 | TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); |
1559 | WorkerCacheInterface* worker_cache = get_worker_cache(); |
1560 | auto entry = new ReffedClientGraph( |
1561 | handle_, opts, std::move(client_graph), session_opts_, |
1562 | stats_publisher_factory_, is_partial, worker_cache, |
1563 | !should_delete_worker_sessions_); |
1564 | iter = m->insert({hash, entry}).first; |
      VLOG(1) << "Preparing to execute new graph";
1566 | } |
1567 | *out_rcg = iter->second; |
1568 | (*out_rcg)->Ref(); |
1569 | *out_count = (*out_rcg)->get_and_increment_execution_count(); |
1570 | } |
1571 | return OkStatus(); |
1572 | } |
1573 | |
1574 | void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, |
1575 | RCGMap* rcg_map) { |
  VLOG(1) << "Discarding all reffed graphs";
1577 | for (auto p : *rcg_map) { |
1578 | ReffedClientGraph* rcg = p.second; |
1579 | if (to_unref) { |
1580 | to_unref->push_back(rcg); |
1581 | } else { |
1582 | rcg->Unref(); |
1583 | } |
1584 | } |
1585 | rcg_map->clear(); |
1586 | } |
1587 | |
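// Returns a new step id. For collective graphs the id is drawn from the
// shared step-id sequence (retrying until a valid id is available); otherwise
// a random id with the reserved high bits cleared is used.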
1588 | uint64 MasterSession::NewStepId(int64_t graph_key) { |
1589 | if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) { |
    // StepId must leave the most-significant 7 bits empty for future use;
    // keep only the low 57 bits of the random value.
    return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1592 | } else { |
1593 | uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key); |
1594 | int32_t retry_count = 0; |
1595 | while (static_cast<int64_t>(step_id) == CollectiveExecutor::kInvalidId) { |
1596 | Notification note; |
1597 | Status status; |
1598 | env_->collective_executor_mgr->RefreshStepIdSequenceAsync( |
1599 | graph_key, [&status, ¬e](const Status& s) { |
1600 | status = s; |
1601 | note.Notify(); |
1602 | }); |
1603 | note.WaitForNotification(); |
1604 | if (!status.ok()) { |
1605 | LOG(ERROR) << "Bad status from " |
1606 | "collective_executor_mgr->RefreshStepIdSequence: " |
                   << status << ". Retrying.";
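        // Back off linearly with the retry count, capped at one minute.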
1608 | int64_t delay_micros = std::min(60000000LL, 1000000LL * ++retry_count); |
1609 | Env::Default()->SleepForMicroseconds(delay_micros); |
1610 | } else { |
1611 | step_id = env_->collective_executor_mgr->NextStepId(graph_key); |
1612 | } |
1613 | } |
1614 | return step_id; |
1615 | } |
1616 | } |
1617 | |
1618 | Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, |
1619 | PartialRunSetupResponse* resp) { |
1620 | std::vector<string> inputs, outputs, targets; |
1621 | for (const auto& feed : req->feed()) { |
1622 | inputs.push_back(feed); |
1623 | } |
1624 | for (const auto& fetch : req->fetch()) { |
1625 | outputs.push_back(fetch); |
1626 | } |
1627 | for (const auto& target : req->target()) { |
1628 | targets.push_back(target); |
1629 | } |
1630 | |
1631 | string handle = std::to_string(partial_run_handle_counter_.fetch_add(1)); |
1632 | |
1633 | ReffedClientGraph* rcg = nullptr; |
1634 | |
1635 | // Prepare. |
1636 | BuildGraphOptions opts; |
1637 | BuildBuildGraphOptions(*req, &opts); |
1638 | int64_t count = 0; |
1639 | TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count)); |
1640 | |
1641 | rcg->Ref(); |
1642 | RunState* run_state = |
1643 | new RunState(inputs, outputs, rcg, |
1644 | NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count); |
1645 | { |
1646 | mutex_lock l(mu_); |
1647 | partial_runs_.emplace( |
1648 | std::make_pair(handle, std::unique_ptr<RunState>(run_state))); |
1649 | } |
1650 | |
1651 | TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg)); |
1652 | |
1653 | resp->set_partial_run_handle(handle); |
1654 | return OkStatus(); |
1655 | } |
1656 | |
1657 | Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, |
1658 | MutableRunStepResponseWrapper* resp) { |
1659 | UpdateLastAccessTime(); |
1660 | { |
1661 | mutex_lock l(mu_); |
1662 | if (closed_) { |
      return errors::FailedPrecondition("Session is closed.");
1664 | } |
1665 | ++num_running_; |
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriately decrement the num_running_ counter.
1668 | } |
1669 | Status status; |
1670 | if (!req.partial_run_handle().empty()) { |
1671 | status = DoPartialRun(opts, req, resp); |
1672 | } else { |
1673 | status = DoRunWithLocalExecution(opts, req, resp); |
1674 | } |
1675 | return status; |
1676 | } |
1677 | |
1678 | // Decrements num_running_ and broadcasts if num_running_ is zero. |
1679 | void MasterSession::MarkRunCompletion() { |
1680 | mutex_lock l(mu_); |
1681 | --num_running_; |
1682 | if (num_running_ == 0) { |
1683 | num_running_is_zero_.notify_all(); |
1684 | } |
1685 | } |
1686 | |
1687 | Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { |
  // Registers the subgraphs if we haven't done so already.
1689 | PartitionOptions popts; |
1690 | popts.node_to_loc = SplitByWorker; |
  // The closures popts.{new_name,get_incarnation} are called synchronously in
  // RegisterPartitions() below, so they do not need a Ref()/Unref() pair to
  // keep "this" alive during the closure.
1694 | popts.new_name = [this](const string& prefix) { |
1695 | mutex_lock l(mu_); |
    return strings::StrCat(prefix, "_S", next_node_id_++);
1697 | }; |
1698 | popts.get_incarnation = [this](const string& name) -> int64 { |
1699 | Device* d = devices_->FindDeviceByName(name); |
1700 | if (d == nullptr) { |
1701 | return PartitionOptions::kIllegalIncarnation; |
1702 | } else { |
1703 | return d->attributes().incarnation(); |
1704 | } |
1705 | }; |
1706 | popts.control_flow_added = false; |
1707 | const bool enable_bfloat16_sendrecv = |
1708 | session_opts_.config.graph_options().enable_bfloat16_sendrecv(); |
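  // When bfloat16 send/recv is enabled, DT_FLOAT tensors are cast to
  // DT_BFLOAT16 across partition boundaries to reduce transfer size.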
1709 | popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) { |
1710 | if (e->IsControlEdge()) { |
1711 | return DT_FLOAT; |
1712 | } |
1713 | DataType dtype = BaseType(e->src()->output_type(e->src_output())); |
1714 | if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) { |
1715 | return DT_BFLOAT16; |
1716 | } else { |
1717 | return dtype; |
1718 | } |
1719 | }; |
1720 | if (session_opts_.config.graph_options().enable_recv_scheduling()) { |
1721 | popts.scheduling_for_recvs = true; |
1722 | popts.need_to_record_start_times = true; |
1723 | } |
1724 | |
1725 | TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts))); |
1726 | |
1727 | return OkStatus(); |
1728 | } |
1729 | |
1730 | Status MasterSession::DoPartialRun(CallOptions* opts, |
1731 | const RunStepRequestWrapper& req, |
1732 | MutableRunStepResponseWrapper* resp) { |
1733 | auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); |
1734 | const string& prun_handle = req.partial_run_handle(); |
1735 | RunState* run_state = nullptr; |
1736 | { |
1737 | mutex_lock l(mu_); |
1738 | auto it = partial_runs_.find(prun_handle); |
1739 | if (it == partial_runs_.end()) { |
1740 | return errors::InvalidArgument( |
          "Must run PartialRunSetup before performing partial runs");
1742 | } |
1743 | run_state = it->second.get(); |
1744 | } |
1745 | // CollectiveOps are not supported in partial runs. |
1746 | if (req.options().experimental().collective_graph_key() != |
1747 | BuildGraphOptions::kNoCollectiveGraphKey) { |
1748 | return errors::InvalidArgument( |
1749 | "PartialRun does not support Collective ops. collective_graph_key " |
        "must be kNoCollectiveGraphKey.");
1751 | } |
1752 | |
1753 | // If this is the first partial run, initialize the PerStepState. |
1754 | if (!run_state->step_started) { |
1755 | run_state->step_started = true; |
1756 | PerStepState pss; |
1757 | |
1758 | const auto count = run_state->count; |
1759 | pss.collect_timeline = |
1760 | req.options().trace_level() == RunOptions::FULL_TRACE; |
1761 | pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; |
1762 | pss.report_tensor_allocations_upon_oom = |
1763 | req.options().report_tensor_allocations_upon_oom(); |
1764 | |
    // Build the cost model every 'build_cost_model_every' steps after
    // skipping an initial 'build_cost_model_after' steps.
1768 | const int64_t build_cost_model_after = |
1769 | session_opts_.config.graph_options().build_cost_model_after(); |
1770 | const int64_t build_cost_model_every = |
1771 | session_opts_.config.graph_options().build_cost_model(); |
1772 | pss.collect_costs = |
1773 | build_cost_model_every > 0 && |
1774 | ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); |
1775 | pss.collect_partition_graphs = req.options().output_partition_graphs(); |
1776 | |
1777 | std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler( |
1778 | run_state->step_id, count, req.options()); |
1779 | if (ph) { |
1780 | pss.collect_timeline = true; |
1781 | pss.collect_rpcs = ph->should_collect_rpcs(); |
1782 | } |
1783 | |
1784 | run_state->pss = std::move(pss); |
1785 | run_state->ph = std::move(ph); |
1786 | } |
1787 | |
1788 | // Make sure that this is a new set of feeds that are still pending. |
1789 | for (size_t i = 0; i < req.num_feeds(); ++i) { |
1790 | const string& feed = req.feed_name(i); |
1791 | auto it = run_state->pending_inputs.find(feed); |
1792 | if (it == run_state->pending_inputs.end()) { |
1793 | return errors::InvalidArgument( |
          "The feed ", feed, " was not specified in partial_run_setup.");
1795 | } else if (it->second) { |
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
1798 | } |
1799 | } |
1800 | // Check that this is a new set of fetches that are still pending. |
1801 | for (size_t i = 0; i < req.num_fetches(); ++i) { |
1802 | const string& fetch = req.fetch_name(i); |
1803 | auto it = run_state->pending_outputs.find(fetch); |
1804 | if (it == run_state->pending_outputs.end()) { |
1805 | return errors::InvalidArgument( |
          "The fetch ", fetch, " was not specified in partial_run_setup.");
1807 | } else if (it->second) { |
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
1810 | } |
1811 | } |
1812 | |
1813 | // Ensure that the requested fetches can be computed from the provided feeds. |
1814 | { |
1815 | mutex_lock l(mu_); |
1816 | TF_RETURN_IF_ERROR( |
1817 | run_state->rcg->CheckFetches(req, run_state, execution_state_.get())); |
1818 | } |
1819 | |
1820 | // Determine if this partial run satisfies all the pending inputs and outputs. |
1821 | for (size_t i = 0; i < req.num_feeds(); ++i) { |
1822 | auto it = run_state->pending_inputs.find(req.feed_name(i)); |
1823 | it->second = true; |
1824 | } |
1825 | for (size_t i = 0; i < req.num_fetches(); ++i) { |
1826 | auto it = run_state->pending_outputs.find(req.fetch_name(i)); |
1827 | it->second = true; |
1828 | } |
1829 | bool is_last_partial_run = run_state->PendingDone(); |
1830 | |
1831 | Status s = run_state->rcg->RunPartitions( |
1832 | env_, run_state->step_id, run_state->count, &run_state->pss, opts, req, |
1833 | resp, &cancellation_manager_, is_last_partial_run); |
1834 | |
1835 | // Delete the run state if there is an error or all fetches are done. |
1836 | if (!s.ok() || is_last_partial_run) { |
1837 | ReffedClientGraph* rcg = run_state->rcg; |
1838 | run_state->pss.end_micros = Env::Default()->NowMicros(); |
1839 | // Schedule post-processing and cleanup to be done asynchronously. |
1840 | Ref(); |
1841 | rcg->Ref(); |
1842 | rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), |
1843 | req.options(), resp->mutable_metadata()); |
1844 | cleanup.release(); // MarkRunCompletion called in done closure. |
1845 | rcg->CleanupPartitionsAsync( |
1846 | run_state->step_id, [this, rcg, prun_handle](const Status& s) { |
1847 | if (!s.ok()) { |
1848 | LOG(ERROR) << "Cleanup partition error: " << s; |
1849 | } |
1850 | rcg->Unref(); |
1851 | MarkRunCompletion(); |
1852 | Unref(); |
1853 | }); |
1854 | mutex_lock l(mu_); |
1855 | partial_runs_.erase(prun_handle); |
1856 | } |
1857 | return s; |
1858 | } |
1859 | |
1860 | Status MasterSession::CreateDebuggerState( |
1861 | const DebugOptions& debug_options, const RunStepRequestWrapper& req, |
1862 | int64_t rcg_execution_count, |
1863 | std::unique_ptr<DebuggerStateInterface>* debugger_state) { |
1864 | TF_RETURN_IF_ERROR( |
1865 | DebuggerStateRegistry::CreateState(debug_options, debugger_state)); |
1866 | |
1867 | std::vector<string> input_names; |
1868 | for (size_t i = 0; i < req.num_feeds(); ++i) { |
1869 | input_names.push_back(req.feed_name(i)); |
1870 | } |
1871 | std::vector<string> output_names; |
1872 | for (size_t i = 0; i < req.num_fetches(); ++i) { |
1873 | output_names.push_back(req.fetch_name(i)); |
1874 | } |
1875 | std::vector<string> target_names; |
1876 | for (size_t i = 0; i < req.num_targets(); ++i) { |
1877 | target_names.push_back(req.target_name(i)); |
1878 | } |
1879 | |
1880 | // TODO(cais): We currently use -1 as a dummy value for session run count. |
1881 | // While this counter value is straightforward to define and obtain for |
1882 | // DirectSessions, it is less so for non-direct Sessions. Devise a better |
1883 | // way to get its value when the need arises. |
1884 | TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( |
1885 | debug_options.global_step(), rcg_execution_count, rcg_execution_count, |
1886 | input_names, output_names, target_names)); |
1887 | |
1888 | return OkStatus(); |
1889 | } |
1890 | |
1891 | void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg, |
1892 | const RunOptions& run_options, |
1893 | uint64 step_id, int64_t count, |
1894 | PerStepState* out_pss, |
1895 | std::unique_ptr<ProfileHandler>* out_ph) { |
1896 | out_pss->collect_timeline = |
1897 | run_options.trace_level() == RunOptions::FULL_TRACE; |
1898 | out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE; |
1899 | out_pss->report_tensor_allocations_upon_oom = |
1900 | run_options.report_tensor_allocations_upon_oom(); |
  // Build the cost model every 'build_cost_model_every' steps after skipping
  // an initial 'build_cost_model_after' steps.
1903 | const int64_t build_cost_model_after = |
1904 | session_opts_.config.graph_options().build_cost_model_after(); |
1905 | const int64_t build_cost_model_every = |
1906 | session_opts_.config.graph_options().build_cost_model(); |
1907 | out_pss->collect_costs = |
1908 | build_cost_model_every > 0 && |
1909 | ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); |
1910 | out_pss->collect_partition_graphs = run_options.output_partition_graphs(); |
1911 | |
1912 | *out_ph = rcg->GetProfileHandler(step_id, count, run_options); |
1913 | if (*out_ph) { |
1914 | out_pss->collect_timeline = true; |
1915 | out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs(); |
1916 | } |
1917 | } |
1918 | |
1919 | Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, |
1920 | uint64 step_id, |
1921 | const RunOptions& run_options, |
1922 | PerStepState* pss, |
1923 | const std::unique_ptr<ProfileHandler>& ph, |
1924 | const Status& run_status, |
1925 | RunMetadata* out_run_metadata) { |
1926 | Status s = run_status; |
1927 | if (s.ok()) { |
1928 | pss->end_micros = Env::Default()->NowMicros(); |
1929 | if (rcg->collective_graph_key() != |
1930 | BuildGraphOptions::kNoCollectiveGraphKey) { |
1931 | env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(), |
1932 | step_id); |
1933 | } |
1934 | // Schedule post-processing and cleanup to be done asynchronously. |
1935 | rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); |
1936 | } else if (errors::IsCancelled(s)) { |
1937 | mutex_lock l(mu_); |
1938 | if (closed_) { |
1939 | if (garbage_collected_) { |
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
1946 | } |
1947 | } |
1948 | } |
1949 | Ref(); |
1950 | rcg->Ref(); |
1951 | rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { |
1952 | if (!s.ok()) { |
1953 | LOG(ERROR) << "Cleanup partition error: " << s; |
1954 | } |
1955 | rcg->Unref(); |
1956 | MarkRunCompletion(); |
1957 | Unref(); |
1958 | }); |
1959 | return s; |
1960 | } |
1961 | |
1962 | Status MasterSession::DoRunWithLocalExecution( |
1963 | CallOptions* opts, const RunStepRequestWrapper& req, |
1964 | MutableRunStepResponseWrapper* resp) { |
1965 | VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString(); |
1966 | PerStepState pss; |
1967 | pss.start_micros = Env::Default()->NowMicros(); |
1968 | auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); |
1969 | |
1970 | // Prepare. |
1971 | BuildGraphOptions bgopts; |
1972 | BuildBuildGraphOptions(req, session_opts_.config, &bgopts); |
1973 | ReffedClientGraph* rcg = nullptr; |
1974 | int64_t count; |
1975 | TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count)); |
1976 | |
1977 | // Unref "rcg" when out of scope. |
1978 | core::ScopedUnref unref(rcg); |
1979 | |
1980 | std::unique_ptr<DebuggerStateInterface> debugger_state; |
1981 | const DebugOptions& debug_options = req.options().debug_options(); |
1982 | |
1983 | if (!debug_options.debug_tensor_watch_opts().empty()) { |
1984 | TF_RETURN_IF_ERROR( |
1985 | CreateDebuggerState(debug_options, req, count, &debugger_state)); |
1986 | } |
1987 | TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg)); |
1988 | |
  // NewStepId() reserves the most-significant bits of the step_id for
  // future use.
1991 | uint64 step_id = NewStepId(rcg->collective_graph_key()); |
  TRACEPRINTF("stepid %llu", step_id);
1993 | |
1994 | std::unique_ptr<ProfileHandler> ph; |
1995 | FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph); |
1996 | |
1997 | if (pss.collect_partition_graphs && |
1998 | session_opts_.config.experimental().disable_output_partition_graphs()) { |
1999 | return errors::InvalidArgument( |
2000 | "RunOptions.output_partition_graphs() is not supported when " |
        "disable_output_partition_graphs is true.");
2002 | } |
2003 | |
2004 | Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, |
2005 | &cancellation_manager_, false); |
2006 | |
2007 | cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). |
2008 | return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s, |
2009 | resp->mutable_metadata()); |
2010 | } |
2011 | |
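// Callable lifecycle: MakeCallable() builds and registers a client graph for
// the given CallableOptions and returns a handle; RunCallable() executes that
// graph; ReleaseCallable() drops the cached graph.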
2012 | Status MasterSession::MakeCallable(const MakeCallableRequest& req, |
2013 | MakeCallableResponse* resp) { |
2014 | UpdateLastAccessTime(); |
2015 | |
2016 | BuildGraphOptions opts; |
2017 | opts.callable_options = req.options(); |
2018 | opts.use_function_convention = false; |
2019 | |
2020 | ReffedClientGraph* callable; |
2021 | |
2022 | { |
2023 | mutex_lock l(mu_); |
2024 | if (closed_) { |
      return errors::FailedPrecondition("Session is closed.");
2026 | } |
2027 | std::unique_ptr<ClientGraph> client_graph; |
2028 | TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); |
2029 | callable = new ReffedClientGraph(handle_, opts, std::move(client_graph), |
2030 | session_opts_, stats_publisher_factory_, |
2031 | false /* is_partial */, get_worker_cache(), |
2032 | !should_delete_worker_sessions_); |
2033 | } |
2034 | |
2035 | Status s = BuildAndRegisterPartitions(callable); |
2036 | if (!s.ok()) { |
2037 | callable->Unref(); |
2038 | return s; |
2039 | } |
2040 | |
2041 | uint64 handle; |
2042 | { |
2043 | mutex_lock l(mu_); |
2044 | handle = next_callable_handle_++; |
2045 | callables_[handle] = callable; |
2046 | } |
2047 | |
2048 | resp->set_handle(handle); |
2049 | return OkStatus(); |
2050 | } |
2051 | |
2052 | Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, |
2053 | const RunCallableRequest& req, |
2054 | RunCallableResponse* resp) { |
2055 | VLOG(2) << "DoRunCallable req: " << req.DebugString(); |
2056 | PerStepState pss; |
2057 | pss.start_micros = Env::Default()->NowMicros(); |
2058 | auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); |
2059 | |
2060 | // Prepare. |
2061 | int64_t count = rcg->get_and_increment_execution_count(); |
2062 | |
2063 | const uint64 step_id = NewStepId(rcg->collective_graph_key()); |
  TRACEPRINTF("stepid %llu", step_id);
2065 | |
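  // The RunOptions for a callable are fixed at MakeCallable() time as part of
  // its CallableOptions.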
2066 | const RunOptions& run_options = rcg->callable_options().run_options(); |
2067 | |
2068 | if (run_options.timeout_in_ms() != 0) { |
2069 | opts->SetTimeout(run_options.timeout_in_ms()); |
2070 | } |
2071 | |
2072 | std::unique_ptr<ProfileHandler> ph; |
2073 | FillPerStepState(rcg, run_options, step_id, count, &pss, &ph); |
2074 | Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, |
2075 | &cancellation_manager_); |
2076 | cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). |
2077 | return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s, |
2078 | resp->mutable_metadata()); |
2079 | } |
2080 | |
2081 | Status MasterSession::RunCallable(CallOptions* opts, |
2082 | const RunCallableRequest& req, |
2083 | RunCallableResponse* resp) { |
2084 | UpdateLastAccessTime(); |
2085 | ReffedClientGraph* callable; |
2086 | { |
2087 | mutex_lock l(mu_); |
2088 | if (closed_) { |
      return errors::FailedPrecondition("Session is closed.");
2090 | } |
2091 | int64_t handle = req.handle(); |
2092 | if (handle >= next_callable_handle_) { |
      return errors::InvalidArgument("No such callable handle: ", handle);
2094 | } |
2095 | auto iter = callables_.find(req.handle()); |
2096 | if (iter == callables_.end()) { |
2097 | return errors::InvalidArgument( |
          "Attempted to run callable after handle was released: ", handle);
2099 | } |
2100 | callable = iter->second; |
2101 | callable->Ref(); |
2102 | ++num_running_; |
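    // Decremented by MarkRunCompletion() once the run kicked off by
    // DoRunCallable() below finishes.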
2103 | } |
2104 | core::ScopedUnref unref_callable(callable); |
2105 | return DoRunCallable(opts, callable, req, resp); |
2106 | } |
2107 | |
2108 | Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, |
2109 | ReleaseCallableResponse* resp) { |
2110 | UpdateLastAccessTime(); |
2111 | ReffedClientGraph* to_unref = nullptr; |
2112 | { |
2113 | mutex_lock l(mu_); |
2114 | auto iter = callables_.find(req.handle()); |
2115 | if (iter != callables_.end()) { |
2116 | to_unref = iter->second; |
2117 | callables_.erase(iter); |
2118 | } |
2119 | } |
2120 | if (to_unref != nullptr) { |
2121 | to_unref->Unref(); |
2122 | } |
2123 | return OkStatus(); |
2124 | } |
2125 | |
2126 | Status MasterSession::Close() { |
2127 | { |
2128 | mutex_lock l(mu_); |
2129 | closed_ = true; // All subsequent calls to Run() or Extend() will fail. |
2130 | } |
2131 | cancellation_manager_.StartCancel(); |
2132 | std::vector<ReffedClientGraph*> to_unref; |
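  // Wait for all in-flight runs to finish before releasing the cached graphs.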
2133 | { |
2134 | mutex_lock l(mu_); |
2135 | while (num_running_ != 0) { |
2136 | num_running_is_zero_.wait(l); |
2137 | } |
2138 | ClearRunsTable(&to_unref, &run_graphs_); |
2139 | ClearRunsTable(&to_unref, &partial_run_graphs_); |
2140 | ClearRunsTable(&to_unref, &callables_); |
2141 | } |
2142 | for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); |
2143 | if (should_delete_worker_sessions_) { |
2144 | Status s = DeleteWorkerSessions(); |
2145 | if (!s.ok()) { |
2146 | LOG(WARNING) << s; |
2147 | } |
2148 | } |
2149 | return OkStatus(); |
2150 | } |
2151 | |
2152 | void MasterSession::GarbageCollect() { |
2153 | { |
2154 | mutex_lock l(mu_); |
2155 | closed_ = true; |
2156 | garbage_collected_ = true; |
2157 | } |
2158 | cancellation_manager_.StartCancel(); |
2159 | Unref(); |
2160 | } |
2161 | |
2162 | MasterSession::RunState::RunState(const std::vector<string>& input_names, |
2163 | const std::vector<string>& output_names, |
2164 | ReffedClientGraph* rcg, const uint64 step_id, |
2165 | const int64_t count) |
2166 | : rcg(rcg), step_id(step_id), count(count) { |
2167 | // Initially all the feeds and fetches are pending. |
2168 | for (auto& name : input_names) { |
2169 | pending_inputs[name] = false; |
2170 | } |
2171 | for (auto& name : output_names) { |
2172 | pending_outputs[name] = false; |
2173 | } |
2174 | } |
2175 | |
2176 | MasterSession::RunState::~RunState() { |
2177 | if (rcg) rcg->Unref(); |
2178 | } |
2179 | |
2180 | bool MasterSession::RunState::PendingDone() const { |
2181 | for (const auto& it : pending_inputs) { |
2182 | if (!it.second) return false; |
2183 | } |
2184 | for (const auto& it : pending_outputs) { |
2185 | if (!it.second) return false; |
2186 | } |
2187 | return true; |
2188 | } |
2189 | |
2190 | } // end namespace tensorflow |
2191 | |