/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <chrono>  // NOLINT(build/c++11)
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}

GraphMgr::~GraphMgr() {
  for (const auto& p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}
// NOTE: node->device_name() is not set by GraphConstructor. We expect
// that the NodeDefs in the GraphDef given to workers fully specify
// device names.
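// A fully specified device name looks like, for example,
// "/job:worker/replica:0/task:0/device:CPU:0".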
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     FormatNodeDefForError(ndef));
    }
  }
  return OkStatus();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return OkStatus();
}

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// "executors" are filled with one executor per device on success, and
// the caller takes ownership of the returned executors.
Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          const ConfigProto& config_proto,
                          int64_t collective_graph_key, WorkerSession* session,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          Item* item) {
  item->session = handle;
  item->collective_graph_key = collective_graph_key;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
  // does that below.
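  // The Rendezvous::Factory below ties function execution to the per-step
  // rendezvous owned by this worker's RendezvousMgr: the first lambda looks
  // up the step's remote rendezvous and binds it to "session"; the second
  // releases the rendezvous when the step is cleaned up.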
  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, /*config=*/&config_proto,
      gdef.versions().producer(), item->lib_def.get(),
      graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [this, session](const int64_t step_id, const DeviceMgr*,
                          Rendezvous** r) -> Status {
            auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
            TF_RETURN_IF_ERROR(remote_r->Initialize(session));
            *r = remote_r;
            return OkStatus();
          },
          [this](const int64_t step_id) {
            this->worker_env_->rendezvous_mgr->Cleanup(step_id);
            return OkStatus();
          }}));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  opts.validate_nodes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64_t {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
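  // A device's incarnation number changes when its hosting worker restarts,
  // so recording it in the partitioned send/recv pairs lets peers detect
  // that they are talking to a stale device instance.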
  popts.flib_def = item->lib_def.get();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(handle);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }

    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel =
        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached. On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                     create_fn);
        };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
        delete kernel;
      }
    };
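    // Kernels owned by the op segment are deliberately not deleted here;
    // they are released when the session's hold on the segment is dropped
    // in GraphMgr::Item's destructor.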

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       GraphOptimizer::Options());

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = std::move(subgraph);
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
  }
  return OkStatus();
}
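
// A typical lifecycle for a registered graph (sketch; full argument lists
// elided):
//
//   string graph_handle;
//   TF_RETURN_IF_ERROR(graph_mgr->Register(..., &graph_handle));
//   graph_mgr->ExecuteAsync(graph_handle, step_id, ...);  // possibly many steps
//   TF_RETURN_IF_ERROR(graph_mgr->Deregister(graph_handle));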
Status GraphMgr::Register(const string& handle, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          const ConfigProto& config_proto,
                          int64_t collective_graph_key, WorkerSession* session,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, graph_options, debug_options, config_proto,
                      collective_graph_key, session, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
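    // Graph handles are 16-hex-digit strings minted from a process-local
    // counter, e.g. "00000000000000a1".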
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return OkStatus();
}

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return OkStatus();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return OkStatus();
}
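
// Sends the named tensors in "in" to the rendezvous for "step_id" so the
// executors running that step can consume them. The names are fully
// qualified rendezvous keys produced by the caller.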
Status GraphMgr::SendInputs(const int64_t step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  size_t input_size = 0;
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
    input_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphInputTensors(input_size);
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}

Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  if (!s.ok()) {
    // Failing to fetch the outputs should not be possible, so rewrite the
    // error status to an INTERNAL error.
    s = errors::Internal("Failed to fetch outputs for step ", step_id,
                         ". (Original error message: ", s.error_message(),
                         ")");
  }
  size_t output_size = 0;
  for (auto& p : *out) {
    output_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphOutputTensors(output_size);
  return s;
}

void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor>* received_keys = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_keys->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_keys->push_back(p.second);
  }
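  // "received_keys" is heap-allocated because the callback below may outlive
  // this stack frame; the callback is responsible for deleting it.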
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_keys,
      [done, rendezvous, received_keys, out, keys](const Status s) {
        rendezvous->Unref();
        size_t output_size = 0;
        for (int i = 0, end = keys.size(); i < end; ++i) {
          (*out)[keys[i]] = (*received_keys)[i];
          output_size += (*out)[keys[i]].AllocatedBytes();
        }
        metrics::RecordGraphOutputTensors(output_size);
        delete received_keys;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(
    const string& handle, const int64_t step_id, const ExecutorOpts& opts,
    const NamedTensors& in, WorkerSession* session,
    StepStatsCollector* collector, MutableRunGraphResponseWrapper* response,
    CancellationManager* cancellation_manager,
    CoordinationServiceAgent* coordination_service_agent, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
      [step_id] {
        return profiler::TraceMeEncode(
            "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);
  // Looks up the item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
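  // "ce_handle" holds one reference on the step's collective executor
  // (created on demand) when the graph contains collective ops; it is
  // released via `delete ce_handle` once the step finishes or fails.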
  // Sends values specified by the caller.
  size_t input_size = 0;
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
      input_size += p.second.AllocatedBytes();
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session, start_time_usecs,
      coordination_service_agent,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in GraphMgr::ExecuteAsync.
            [step_id] {
              return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done(s);
        metrics::RecordGraphInputTensors(input_size);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}

void GraphMgr::StartParallelExecutors(
    const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
    CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
    CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
    WorkerSession* session, int64_t start_time_usecs,
    CoordinationServiceAgent* coordination_service_agent, StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
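  // The step container scopes per-step resources: when it is deleted after
  // the barrier fires, the lambda above clears the matching resource
  // containers on the devices managed by this worker.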
  // NOTE: Transfer one ref of rendezvous and item.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  args.step_id = step_id;
  args.rendezvous = rendezvous;
  args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  args.start_time_usecs = start_time_usecs;
  args.coordination_service_agent = coordination_service_agent;

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // Line below is equivalent to this code, but does one less indirect call:
  //   args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to
    // assign fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph.get();
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow