1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/graph_mgr.h" |
17 | |
18 | #include <chrono> // NOLINT(build/c++11) |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/common_runtime/build_graph_options.h" |
22 | #include "tensorflow/core/common_runtime/constant_folding.h" |
23 | #include "tensorflow/core/common_runtime/debugger_state_interface.h" |
24 | #include "tensorflow/core/common_runtime/device.h" |
25 | #include "tensorflow/core/common_runtime/device_mgr.h" |
26 | #include "tensorflow/core/common_runtime/function.h" |
27 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
28 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
29 | #include "tensorflow/core/common_runtime/memory_types.h" |
30 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
31 | #include "tensorflow/core/common_runtime/process_util.h" |
32 | #include "tensorflow/core/common_runtime/rendezvous_util.h" |
33 | #include "tensorflow/core/common_runtime/step_stats_collector.h" |
34 | #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" |
35 | #include "tensorflow/core/framework/cancellation.h" |
36 | #include "tensorflow/core/framework/collective.h" |
37 | #include "tensorflow/core/framework/log_memory.h" |
38 | #include "tensorflow/core/framework/metrics.h" |
39 | #include "tensorflow/core/framework/node_def.pb.h" |
40 | #include "tensorflow/core/framework/node_def_util.h" |
41 | #include "tensorflow/core/framework/versions.pb.h" |
42 | #include "tensorflow/core/graph/graph.h" |
43 | #include "tensorflow/core/graph/graph_partition.h" |
44 | #include "tensorflow/core/graph/validate.h" |
45 | #include "tensorflow/core/lib/core/errors.h" |
46 | #include "tensorflow/core/lib/strings/stringprintf.h" |
47 | #include "tensorflow/core/platform/env.h" |
48 | #include "tensorflow/core/platform/logging.h" |
49 | #include "tensorflow/core/platform/mutex.h" |
50 | #include "tensorflow/core/platform/tracing.h" |
51 | #include "tensorflow/core/platform/types.h" |
52 | #include "tensorflow/core/profiler/lib/connected_traceme.h" |
53 | #include "tensorflow/core/profiler/lib/traceme_encode.h" |
54 | #include "tensorflow/core/protobuf/worker.pb.h" |
55 | #include "tensorflow/core/util/env_var.h" |
56 | |
57 | namespace tensorflow { |
58 | |
59 | GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr) |
60 | : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { |
61 | // The default value of sync_on_finish will be flipped soon and this |
62 | // environment variable will be removed as well. |
63 | Status status = |
64 | ReadBoolFromEnvVar("TF_SYNC_ON_FINISH" , true, &sync_on_finish_); |
65 | if (!status.ok()) { |
66 | LOG(ERROR) << status.error_message(); |
67 | } |
68 | } |
69 | |
70 | GraphMgr::~GraphMgr() { |
71 | for (const auto& p : table_) p.second->Unref(); |
72 | } |
73 | |
74 | GraphMgr::Item::~Item() { |
75 | for (const auto& unit : this->units) { |
76 | CHECK_NOTNULL(unit.device); |
77 | if (!graph_mgr->skip_cost_models_) { |
78 | graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get()); |
79 | } |
80 | delete unit.root; |
81 | unit.device->op_segment()->RemoveHold(this->session); |
82 | } |
83 | } |
84 | |
85 | // NOTE: node->device_name() is not set by GraphConstructor. We |
86 | // expects that NodeDef in GraphDef given to workers fully specifies |
87 | // device names. |
88 | static string SplitByDevice(const Node* node) { |
89 | return node->assigned_device_name(); |
90 | } |
91 | |
92 | // Validates "gdef" device specifications. |
93 | static Status ValidateGraphDefForDevices(const GraphDef& gdef) { |
94 | DeviceNameUtils::ParsedName parsed; |
95 | for (const auto& ndef : gdef.node()) { |
96 | if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) { |
97 | return errors::InvalidArgument("Missing device name in: " , |
98 | FormatNodeDefForError(ndef)); |
99 | } |
100 | } |
101 | return OkStatus(); |
102 | } |
103 | |
104 | Status GraphMgr::DecorateAndPublishGraphForDebug( |
105 | const DebugOptions& debug_options, Graph* graph, Device* device) { |
106 | std::unique_ptr<DebugGraphDecoratorInterface> decorator; |
107 | TF_RETURN_IF_ERROR( |
108 | DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); |
109 | TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); |
110 | TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); |
111 | return OkStatus(); |
112 | } |
113 | |
114 | // Creates executors given a graph definition "gdef" of a "session". |
115 | // If a node in "gdef" is shared by other graphs in "session", the |
116 | // same op kernel is reused. E.g., typically a params node is shared |
117 | // by multiple graphs in a session. |
118 | // |
119 | // If "gdef" is assigned to multiple devices, extra nodes (e.g., |
120 | // send/recv nodes) maybe added. The extra nodes' name are generated |
121 | // by calling "new_name(old_name)". |
122 | // |
123 | // "executors" are filled with one executor per device if success and |
124 | // the caller takes the ownership of returned executors. |
125 | Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, |
126 | const GraphOptions& graph_options, |
127 | const DebugOptions& debug_options, |
128 | const ConfigProto& config_proto, |
129 | int64_t collective_graph_key, WorkerSession* session, |
130 | DistributedFunctionLibraryRuntime* cluster_flr, |
131 | Item* item) { |
132 | item->session = handle; |
133 | item->collective_graph_key = collective_graph_key; |
134 | item->lib_def.reset( |
135 | new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library())); |
136 | |
137 | TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef)); |
138 | |
139 | // We don't explicitly Validate the graph def because ConvertGraphDefToGraph |
140 | // does that below. |
141 | item->proc_flr.reset(new ProcessFunctionLibraryRuntime( |
142 | device_mgr_, worker_env_->env, /*config=*/&config_proto, |
143 | gdef.versions().producer(), item->lib_def.get(), |
144 | graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr, |
145 | /*session_metadata=*/nullptr, |
146 | Rendezvous::Factory{ |
147 | [this, session](const int64_t step_id, const DeviceMgr*, |
148 | Rendezvous** r) -> Status { |
149 | auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id); |
150 | TF_RETURN_IF_ERROR(remote_r->Initialize(session)); |
151 | *r = remote_r; |
152 | return OkStatus(); |
153 | }, |
154 | [this](const int64_t step_id) { |
155 | this->worker_env_->rendezvous_mgr->Cleanup(step_id); |
156 | return OkStatus(); |
157 | }})); |
158 | |
159 | // Constructs the graph out of "gdef". |
160 | Graph graph(OpRegistry::Global()); |
161 | GraphConstructorOptions opts; |
162 | opts.allow_internal_ops = true; |
163 | opts.expect_device_spec = true; |
164 | opts.validate_nodes = true; |
165 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph)); |
166 | |
167 | // Splits "graph" into multiple subgraphs by device names. |
168 | std::unordered_map<string, GraphDef> partitions; |
169 | PartitionOptions popts; |
170 | popts.node_to_loc = SplitByDevice; |
171 | popts.new_name = [this](const string& prefix) { |
172 | mutex_lock l(mu_); |
173 | return strings::StrCat(prefix, "_G" , next_id_++); |
174 | }; |
175 | popts.get_incarnation = [this](const string& name) -> int64 { |
176 | Device* device = nullptr; |
177 | Status s = device_mgr_->LookupDevice(name, &device); |
178 | if (s.ok()) { |
179 | return device->attributes().incarnation(); |
180 | } else { |
181 | return PartitionOptions::kIllegalIncarnation; |
182 | } |
183 | }; |
184 | popts.flib_def = item->lib_def.get(); |
185 | popts.control_flow_added = true; |
186 | popts.scheduling_for_recvs = graph_options.enable_recv_scheduling(); |
187 | TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions)); |
188 | if (popts.scheduling_for_recvs) { |
189 | TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions)); |
190 | } |
191 | |
192 | std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs; |
193 | for (auto& partition : partitions) { |
194 | std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global())); |
195 | GraphConstructorOptions device_opts; |
196 | // There are internal operations (e.g., send/recv) that we now allow. |
197 | device_opts.allow_internal_ops = true; |
198 | device_opts.expect_device_spec = true; |
199 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( |
200 | device_opts, std::move(partition.second), device_graph.get())); |
201 | partition_graphs.emplace(partition.first, std::move(device_graph)); |
202 | } |
203 | |
204 | GraphOptimizationPassOptions optimization_options; |
205 | optimization_options.flib_def = item->lib_def.get(); |
206 | optimization_options.partition_graphs = &partition_graphs; |
207 | TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
208 | OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); |
209 | |
210 | LocalExecutorParams params; |
211 | |
212 | item->units.reserve(partitions.size()); |
213 | item->graph_mgr = this; |
214 | const auto& optimizer_opts = graph_options.optimizer_options(); |
215 | GraphOptimizer optimizer(optimizer_opts); |
216 | for (auto& p : partition_graphs) { |
217 | const string& device_name = p.first; |
218 | std::unique_ptr<Graph>& subgraph = p.second; |
219 | item->units.resize(item->units.size() + 1); |
220 | ExecutionUnit* unit = &(item->units.back()); |
221 | |
222 | // Find the device. |
223 | Status s = device_mgr_->LookupDevice(device_name, &unit->device); |
224 | if (!s.ok()) { |
225 | // Remove the empty unit from the item as the item destructor wants all |
226 | // units to have valid devices. |
227 | item->units.pop_back(); |
228 | return s; |
229 | } |
230 | |
231 | // Give the device an opportunity to rewrite its subgraph. |
232 | TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph)); |
233 | |
234 | // Top-level nodes in the graph uses the op segment to cache |
235 | // kernels. Therefore, as long as the executor is alive, we need |
236 | // to ensure the kernels cached for the session are alive. |
237 | auto opseg = unit->device->op_segment(); |
238 | opseg->AddHold(handle); |
239 | |
240 | // Function library runtime. |
241 | FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name()); |
242 | if (lib == nullptr) { |
243 | return errors::InvalidArgument("Cannot find FLR for device: " , |
244 | unit->device->name()); |
245 | } |
246 | |
247 | // Construct the root executor for the subgraph. |
248 | params.device = unit->device; |
249 | params.function_library = lib; |
250 | params.create_kernel = |
251 | [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props, |
252 | OpKernel** kernel) { |
253 | // NOTE(mrry): We must not share function kernels (implemented |
254 | // using `CallOp`) between subgraphs, because `CallOp::handle_` |
255 | // is tied to a particular subgraph. Even if the function itself |
256 | // is stateful, the `CallOp` that invokes it is not. |
257 | if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) { |
258 | return lib->CreateKernel(props, kernel); |
259 | } |
260 | auto create_fn = [lib, &props](OpKernel** kernel) { |
261 | return lib->CreateKernel(props, kernel); |
262 | }; |
263 | // Kernels created for subgraph nodes need to be cached. On |
264 | // cache miss, create_fn() is invoked to create a kernel based |
265 | // on the function library here + global op registry. |
266 | return opseg->FindOrCreate(handle, props->node_def.name(), kernel, |
267 | create_fn); |
268 | }; |
269 | params.delete_kernel = [lib](OpKernel* kernel) { |
270 | if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { |
271 | delete kernel; |
272 | } |
273 | }; |
274 | |
275 | optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph, |
276 | GraphOptimizer::Options()); |
277 | |
278 | // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. |
279 | if (!debug_options.debug_tensor_watch_opts().empty()) { |
280 | TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug( |
281 | debug_options, subgraph.get(), params.device)); |
282 | } |
283 | |
284 | TF_RETURN_IF_ERROR( |
285 | EnsureMemoryTypes(DeviceType(unit->device->device_type()), |
286 | unit->device->name(), subgraph.get())); |
287 | unit->graph = std::move(subgraph); |
288 | unit->build_cost_model = graph_options.build_cost_model(); |
289 | if (unit->build_cost_model > 0) { |
290 | skip_cost_models_ = false; |
291 | } |
292 | TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root)); |
293 | } |
294 | return OkStatus(); |
295 | } |
296 | |
297 | Status GraphMgr::Register(const string& handle, const GraphDef& gdef, |
298 | const GraphOptions& graph_options, |
299 | const DebugOptions& debug_options, |
300 | const ConfigProto& config_proto, |
301 | int64_t collective_graph_key, WorkerSession* session, |
302 | DistributedFunctionLibraryRuntime* cluster_flr, |
303 | string* graph_handle) { |
304 | Item* item = new Item; |
305 | Status s = InitItem(handle, gdef, graph_options, debug_options, config_proto, |
306 | collective_graph_key, session, cluster_flr, item); |
307 | if (!s.ok()) { |
308 | item->Unref(); |
309 | return s; |
310 | } |
311 | |
312 | // Inserts one item into table_. |
313 | { |
314 | mutex_lock l(mu_); |
315 | *graph_handle = |
316 | strings::Printf("%016llx" , static_cast<long long>(++next_id_)); |
317 | item->handle = *graph_handle; |
318 | CHECK(table_.insert({*graph_handle, item}).second); |
319 | } |
320 | return OkStatus(); |
321 | } |
322 | |
323 | Status GraphMgr::Deregister(const string& handle) { |
324 | Item* item = nullptr; |
325 | // Removes one item from table_. |
326 | { |
327 | mutex_lock l(mu_); |
328 | auto iter = table_.find(handle); |
329 | if (iter == table_.end()) { |
330 | return errors::Aborted("Graph handle is not found: " , handle, |
331 | ". Possibly, this worker just restarted." ); |
332 | } |
333 | item = iter->second; |
334 | table_.erase(iter); |
335 | } |
336 | item->Unref(); |
337 | return OkStatus(); |
338 | } |
339 | |
340 | Status GraphMgr::DeregisterAll() { |
341 | std::vector<Item*> items; |
342 | // Removes all items from table_. |
343 | { |
344 | mutex_lock l(mu_); |
345 | for (const auto& entry : table_) { |
346 | items.push_back(entry.second); |
347 | } |
348 | table_.clear(); |
349 | } |
350 | for (auto item : items) { |
351 | item->Unref(); |
352 | } |
353 | return OkStatus(); |
354 | } |
355 | |
356 | Status GraphMgr::SendInputs(const int64_t step_id, const NamedTensors& in) { |
357 | Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); |
358 | std::vector<string> keys; |
359 | std::vector<Tensor> tensors_to_send; |
360 | keys.reserve(in.size()); |
361 | tensors_to_send.reserve(in.size()); |
362 | size_t input_size = 0; |
363 | for (const auto& p : in) { |
364 | keys.push_back(p.first); |
365 | tensors_to_send.push_back(p.second); |
366 | input_size += p.second.AllocatedBytes(); |
367 | } |
368 | metrics::RecordGraphInputTensors(input_size); |
369 | Status s = |
370 | SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send); |
371 | rendezvous->Unref(); |
372 | return s; |
373 | } |
374 | |
375 | Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) { |
376 | Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); |
377 | Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args()); |
378 | rendezvous->Unref(); |
379 | if (!s.ok()) { |
380 | // Failing to fetch the outputs should not be possible, so rewrite the error |
381 | // status to an INTERNAL error. |
382 | s = errors::Internal("Failed to fetch outputs for step " , step_id, |
383 | ". (Original error message: " , s.error_message(), ")" ); |
384 | } |
385 | size_t output_size = 0; |
386 | for (auto& p : *out) { |
387 | output_size += p.second.AllocatedBytes(); |
388 | } |
389 | metrics::RecordGraphOutputTensors(output_size); |
390 | return s; |
391 | } |
392 | |
393 | void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out, |
394 | StatusCallback done) { |
395 | Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); |
396 | std::vector<string> keys; |
397 | std::vector<Tensor>* received_keys = new std::vector<Tensor>; |
398 | keys.reserve(out->size()); |
399 | received_keys->reserve(out->size()); |
400 | for (const auto& p : *out) { |
401 | keys.push_back(p.first); |
402 | received_keys->push_back(p.second); |
403 | } |
404 | RecvOutputsFromRendezvousAsync( |
405 | rendezvous, nullptr, {}, keys, received_keys, |
406 | [done, rendezvous, received_keys, out, keys](const Status s) { |
407 | rendezvous->Unref(); |
408 | size_t output_size = 0; |
409 | for (int i = 0, end = keys.size(); i < end; ++i) { |
410 | (*out)[keys[i]] = (*received_keys)[i]; |
411 | output_size += (*out)[keys[i]].AllocatedBytes(); |
412 | } |
413 | metrics::RecordGraphOutputTensors(output_size); |
414 | delete received_keys; |
415 | done(s); |
416 | }); |
417 | } |
418 | |
// Runs step "step_id" of the graph registered under "handle".
//
// The item is looked up and one reference is held on it while executing. If
// "response" is non-null, its cost graph is used as the cost-model target and,
// when opts.record_partition_graphs() is set, every partition graph is added
// to it. The step's rendezvous is initialized with "session", and a
// CollectiveExecutor handle is created if the item has a collective graph key.
// The inputs "in" are then sent and all execution units are started via
// StartParallelExecutors.
//
// "done" is called directly on the early-exit paths (unknown handle,
// rendezvous initialization failure, send failure). Otherwise it is called
// from the completion callback passed to StartParallelExecutors.
void GraphMgr::ExecuteAsync(
    const string& handle, const int64_t step_id, const ExecutorOpts& opts,
    const NamedTensors& in, WorkerSession* session,
    StepStatsCollector* collector, MutableRunGraphResponseWrapper* response,
    CancellationManager* cancellation_manager,
    CoordinationServiceAgent* coordination_service_agent, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
      [step_id] {
        return profiler::TraceMeEncode(
            "RunGraph" , {{"id" , step_id}, {"_r" , 1} /*root_event*/});
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);
  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: " , handle));
    return;
  }

  // Optional outputs for the caller: the cost graph target and, on request,
  // a copy of every partition graph.
  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  // Find() returns a referenced rendezvous; that reference is released on the
  // error path below or in the completion callback.
  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);
  // Only graphs registered with a collective graph key get a collective
  // executor handle; it is deleted on both exit paths below.
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
  // Sends values specified by the caller.
  size_t input_size = 0;
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
      input_size += p.second.AllocatedBytes();
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  // Initialization or send failed: report it, then release everything this
  // call acquired (collective handle, item ref, rendezvous ref).
  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  // The completion callback reports to "done" first, then records the input
  // size and execution time metrics, and finally releases the rendezvous ref,
  // the item ref, and the collective handle.
  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session, start_time_usecs,
      coordination_service_agent,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in GraphMgr::ExecuteAsync.
            [step_id] {
              return profiler::TraceMeEncode("RunGraphDone" , {{"id" , step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done(s);
        metrics::RecordGraphInputTensors(input_size);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}
515 | |
516 | void GraphMgr::StartParallelExecutors( |
517 | const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous, |
518 | CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, |
519 | CostGraphDef* cost_graph, CancellationManager* cancellation_manager, |
520 | WorkerSession* session, int64_t start_time_usecs, |
521 | CoordinationServiceAgent* coordination_service_agent, StatusCallback done) { |
522 | const int num_units = item->units.size(); |
523 | CHECK_GE(num_units, 1); |
524 | ScopedStepContainer* step_container = new ScopedStepContainer( |
525 | step_id, |
526 | [this](const string& name) { device_mgr_->ClearContainers({name}); }); |
527 | // NOTE: Transfer one ref of rendezvous and item. |
528 | ExecutorBarrier* barrier = |
529 | new ExecutorBarrier(num_units, rendezvous, |
530 | [this, item, collector, cost_graph, step_container, |
531 | done](const Status& s) { |
532 | BuildCostModel(item, collector, cost_graph); |
533 | done(s); |
534 | delete step_container; |
535 | }); |
536 | Executor::Args args; |
537 | args.step_id = step_id; |
538 | args.rendezvous = rendezvous; |
539 | args.collective_executor = ce_handle ? ce_handle->get() : nullptr; |
540 | args.cancellation_manager = cancellation_manager; |
541 | args.stats_collector = collector; |
542 | args.step_container = step_container; |
543 | args.sync_on_finish = sync_on_finish_; |
544 | args.start_time_usecs = start_time_usecs; |
545 | args.coordination_service_agent = coordination_service_agent; |
546 | |
547 | if (LogMemory::IsEnabled()) { |
548 | LogMemory::RecordStep(args.step_id, handle); |
549 | } |
550 | thread::ThreadPool* pool = worker_env_->compute_pool; |
551 | using std::placeholders::_1; |
552 | // Line below is equivalent to this code, but does one less indirect call: |
553 | // args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); }; |
554 | auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1); |
555 | for (const auto& unit : item->units) { |
556 | // TODO(zhengxq): if the device picks its own threadpool, we need to assign |
557 | // less threads to the main compute pool by default. |
558 | thread::ThreadPool* device_thread_pool = |
559 | unit.device->tensorflow_device_thread_pool(); |
560 | if (!device_thread_pool) { |
561 | args.runner = default_runner; |
562 | } else { |
563 | args.runner = |
564 | std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1); |
565 | } |
566 | unit.root->RunAsync(args, barrier->Get()); |
567 | } |
568 | } |
569 | |
570 | void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector, |
571 | CostGraphDef* cost_graph) { |
572 | if (collector && !skip_cost_models_) { |
573 | // Build the cost model |
574 | std::unordered_map<string, const Graph*> device_to_graph; |
575 | for (const auto& unit : item->units) { |
576 | if (unit.build_cost_model > 0) { |
577 | device_to_graph[unit.device->name()] = unit.graph.get(); |
578 | } |
579 | } |
580 | collector->BuildCostModel(&cost_model_manager_, device_to_graph); |
581 | |
582 | if (cost_graph != nullptr) { |
583 | for (const auto& unit : item->units) { |
584 | cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph) |
585 | .IgnoreError(); |
586 | } |
587 | } |
588 | } |
589 | } |
590 | |
591 | } // end namespace tensorflow |
592 | |