1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
16
17#include <algorithm>
18#include <functional>
19#include <iterator>
20#include <optional>
21#include <utility>
22
23#include "absl/container/flat_hash_map.h"
24#include "absl/memory/memory.h"
25#include "absl/strings/str_join.h"
26#include "tensorflow/core/common_runtime/device_set.h"
27#include "tensorflow/core/common_runtime/function.h"
28#include "tensorflow/core/common_runtime/function_optimization_registry.h"
29#include "tensorflow/core/common_runtime/graph_constructor.h"
30#include "tensorflow/core/common_runtime/optimization_registry.h"
31#include "tensorflow/core/common_runtime/partitioning_utils.h"
32#include "tensorflow/core/common_runtime/placer.h"
33#include "tensorflow/core/common_runtime/process_util.h"
34#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
35#include "tensorflow/core/common_runtime/rendezvous_util.h"
36#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
37#include "tensorflow/core/common_runtime/single_threaded_executor.h"
38#include "tensorflow/core/framework/cancellation.h"
39#include "tensorflow/core/framework/function.h"
40#include "tensorflow/core/framework/graph_to_functiondef.h"
41#include "tensorflow/core/framework/metrics.h"
42#include "tensorflow/core/framework/op_kernel.h"
43#include "tensorflow/core/framework/tensor.h"
44#include "tensorflow/core/framework/types.h"
45#include "tensorflow/core/framework/types.pb.h"
46#include "tensorflow/core/graph/graph.h"
47#include "tensorflow/core/graph/graph_node_util.h"
48#include "tensorflow/core/graph/graph_partition.h"
49#include "tensorflow/core/lib/core/errors.h"
50#include "tensorflow/core/lib/gtl/cleanup.h"
51#include "tensorflow/core/lib/gtl/inlined_vector.h"
52#include "tensorflow/core/lib/gtl/map_util.h"
53#include "tensorflow/core/lib/random/random.h"
54#include "tensorflow/core/platform/blocking_counter.h"
55#include "tensorflow/core/platform/notification.h"
56#include "tensorflow/core/util/device_name_utils.h"
57#include "tensorflow/core/util/dump_graph.h"
58#include "tensorflow/core/util/ptr_util.h"
59#include "tensorflow/core/util/reffed_status_callback.h"
60#include "tensorflow/tsl/platform/statusor.h"
61#if !defined(IS_MOBILE_PLATFORM)
62#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
63#endif // IS_MOBILE_PLATFORM
64
65namespace tensorflow {
66
// Sentinel device name: GetFLR(kDefaultFLRDevice) returns the process-level
// (device-less) FunctionLibraryRuntime stored under the nullptr key.
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
68
69void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
70 DistributedFunctionLibraryRuntime* parent, const string& function_name,
71 const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
72 const FunctionLibraryRuntime::InstantiateOptions& options,
73 FunctionLibraryRuntime::DoneCallback done) {
74 {
75 mutex_lock l(mu_);
76 is_cross_process_ = true;
77 if (init_started_) {
78 init_done_.WaitForNotification();
79 done(init_result_);
80 return;
81 }
82 init_started_ = true;
83 }
84 parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
85 [this, done](const Status& s) {
86 init_done_.Notify();
87 done(s);
88 });
89}
90
// Constructs the process-wide FLR. When `device_mgr` is non-null, one
// per-device FunctionLibraryRuntime is created for each local device;
// otherwise a single device-less FLR is installed under the nullptr key.
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent,
    const SessionMetadata* session_metadata,
    Rendezvous::Factory rendezvous_factory)
    : parent_(parent),
      env_(env),
      // Copy the config when provided; an absent config stays nullopt.
      config_(config ? absl::make_optional(*config) : absl::nullopt),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      flr_map_(new std::unordered_map<Device*,
                                      std::unique_ptr<FunctionLibraryRuntime>>),
      next_handle_(0),
      session_metadata_(session_metadata),
      rendezvous_factory_(std::move(rendezvous_factory)),
      optimizer_options_(optimizer_options),
      graph_def_version_(graph_def_version) {
  if (device_mgr == nullptr) {
    // No devices: install the process-level FLR keyed by nullptr, which is
    // what GetFLR(kDefaultFLRDevice) resolves to.
    (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
        graph_def_version, lib_def_, default_thread_pool, optimizer_options,
        session_metadata_, this);
    return;
  }
  // Populate device_set_ and one FLR per local device.
  InitializeDeviceAndFlr();
}
121
122/* static */
123Status ProcessFunctionLibraryRuntime::SendTensors(
124 const string& source_device, const string& target_device,
125 const string& key_prefix, int64_t src_incarnation,
126 gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
127 const std::vector<AllocatorAttributes>& alloc_attrs,
128 RendezvousInterface* rendezvous) {
129 std::vector<string> keys;
130 for (int i = 0; i < tensors_to_send.size(); ++i) {
131 string name = strings::StrCat(key_prefix, i);
132 string key = Rendezvous::CreateKey(source_device, src_incarnation,
133 target_device, name, FrameAndIter(0, 0));
134 keys.push_back(key);
135 }
136 TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
137 rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
138 return OkStatus();
139}
140
141/* static */
142void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
143 const string& source_device, const string& target_device,
144 const string& key_prefix, int64_t src_incarnation, int64_t num_tensors,
145 DeviceContext* device_context,
146 const std::vector<AllocatorAttributes>& alloc_attrs,
147 RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors,
148 StatusCallback done) {
149 std::vector<string> keys;
150 for (int64_t i = 0; i < num_tensors; ++i) {
151 string name = strings::StrCat(key_prefix, i);
152 string key = Rendezvous::CreateKey(source_device, src_incarnation,
153 target_device, name, FrameAndIter(0, 0));
154 keys.push_back(key);
155 }
156 RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
157 received_tensors, std::move(done));
158}
159
160Status ProcessFunctionLibraryRuntime::GetRetTypes(
161 FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) {
162 FunctionLibraryRuntime* flr = nullptr;
163 {
164 tf_shared_lock l(mu_);
165 auto miter = mdevice_data_.find(h);
166 if (miter != mdevice_data_.end()) {
167 *ret_types = miter->second->ret_types_;
168 return OkStatus();
169 }
170 auto fiter = function_data_.find(h);
171 if (fiter != function_data_.end()) {
172 flr = GetFLR(fiter->second->target_device());
173 }
174 }
175 if (flr != nullptr) {
176 return flr->GetRetTypes(h, ret_types);
177 }
178 return errors::InvalidArgument("Handle ", h, " not found.");
179}
180
181Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
182 const string& device_name, int64_t* incarnation) const {
183 FunctionLibraryRuntime* flr = GetFLR(device_name);
184 if (flr == nullptr) {
185 return errors::InvalidArgument("Device name: ", device_name, " not found.");
186 }
187 *incarnation = flr->device()->attributes().incarnation();
188 return OkStatus();
189}
190
191Status ProcessFunctionLibraryRuntime::GetDeviceContext(
192 const string& device_name, DeviceContext** device_context) const {
193 *device_context = nullptr;
194 FunctionLibraryRuntime* flr = GetFLR(device_name);
195 if (flr == nullptr) {
196 return errors::InvalidArgument("Device name: ", device_name, " not found.");
197 }
198 Device* device = flr->device();
199 string device_type = device->parsed_name().type;
200 if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
201 // "TPU_SYSTEM" indicates that `device` is a CPU.
202 return OkStatus();
203 }
204
205 if (device->IsRemoteCallAllowed()) {
206 auto* dev_info = flr->device()->tensorflow_accelerator_device_info();
207 if (dev_info) {
208 *device_context = dev_info->default_context;
209 return OkStatus();
210 }
211 }
212
213 return errors::Internal("Device type: ", device_type,
214 " is currently unsupported for remote ",
215 "function executions");
216}
217
218void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
219 // Reset device_set_ by one of the two following scenarios:
220 // 1) Both cluster-FLR and its remote_device_mgr is available: include local
221 // devices (if any) from the local device_mgr_ as Device type, and include
222 // remote devices from cluster's remote_device_mgr as RemoteDevice type.
223 // 2) Include local devices from the local device_mgr_.
224 // In both scenarios, no device is added more than one times.
225 mutex_lock l(mu_);
226 device_set_ = std::make_shared<DeviceSet>();
227 if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
228 for (auto d : parent_->remote_device_mgr()->ListDevices()) {
229 Device* device = nullptr;
230 if (device_mgr_->LookupDevice(d->name(), &device) == OkStatus()) {
231 // If this device exists in device_mgr, i.e., a local device,
232 // add this device from the instance included in device_mgr_
233 device_set_->AddDevice(device);
234 } else {
235 device_set_->AddDevice(d);
236 }
237 }
238 } else {
239 for (auto d : device_mgr_->ListDevices()) {
240 device_set_->AddDevice(d);
241 }
242 }
243
244 // Update flr_map_ by adding new devices
245 for (Device* d : device_mgr_->ListDevices()) {
246 if ((*flr_map_)[d] == nullptr) {
247 (*flr_map_)[d] = NewFunctionLibraryRuntime(
248 device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
249 graph_def_version_, lib_def_, default_thread_pool_,
250 optimizer_options_, session_metadata_, this);
251 }
252 }
253}
254
255FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
256 const string& device_name) const {
257 Device* device = nullptr;
258 if (device_name != kDefaultFLRDevice) {
259 if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
260 VLOG(4) << "Could not find device: " << device_name;
261 return nullptr;
262 }
263 }
264 const auto& iter = flr_map_->find(device);
265 if (iter == flr_map_->end()) {
266 VLOG(1) << "Could not find device: " << device_name
267 << "in the local process.";
268 return nullptr;
269 }
270 return iter->second.get();
271}
272
273FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
274 const string& function_key, const string& device_name,
275 FunctionLibraryRuntime::LocalHandle local_handle) {
276 mutex_lock l(mu_);
277 return AddHandleLocked(function_key, device_name, local_handle);
278}
279
280FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
281 const string& function_key, const string& device_name,
282 FunctionLibraryRuntime::LocalHandle local_handle) {
283 auto h = next_handle_;
284 function_data_[h] =
285 std::make_unique<FunctionData>(device_name, local_handle, function_key);
286 table_[function_key] = h;
287 next_handle_++;
288 return h;
289}
290
291FunctionLibraryRuntime::Handle
292ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
293 std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
294 mutex_lock l(mu_);
295 auto h = next_handle_;
296 mdevice_data_[h] = std::move(data);
297 table_[function_key] = h;
298 next_handle_++;
299 return h;
300}
301
302bool ProcessFunctionLibraryRuntime::HasMultiDeviceHandle(
303 FunctionLibraryRuntime::Handle handle) const {
304 bool multi_device;
305 {
306 tf_shared_lock l(mu_);
307 multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
308 }
309 return multi_device;
310}
311
312FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
313 const string& function_key) const {
314 tf_shared_lock l(mu_);
315 return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
316}
317
// Returns the device-local handle for `handle` when the function is placed on
// `device_name`, or kInvalidLocalHandle otherwise. With
// `include_multi_device`, a multi-device function consisting of exactly one
// component on `device_name` is resolved to that component's local handle.
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle,
    bool include_multi_device) const {
  tf_shared_lock l(mu_);

  auto miter = mdevice_data_.find(handle);
  if (miter != mdevice_data_.end()) {
    if (!include_multi_device) return kInvalidLocalHandle;

    // Only a multi-device function with a single component (on the requested
    // device) can be treated as a device-local function.
    const MultiDeviceFunctionData& data = *miter->second;
    if (data.glue_.size() != 1) return kInvalidLocalHandle;

    const auto& pair = *data.glue_.begin();
    const string& func_device_name = pair.first;
    const ComponentFunctionData& component_data = pair.second;
    if (func_device_name != device_name) return kInvalidLocalHandle;

    // Replace the given handle with the handle for the single component
    // function.
    handle = component_data.handle;
  }

  // From here on, `handle` (possibly remapped above) must refer to a
  // single-device function targeting `device_name`.
  auto iter = function_data_.find(handle);
  if (iter == function_data_.end()) {
    return kInvalidLocalHandle;
  }
  FunctionData* function_data = iter->second.get();
  if (function_data->target_device() != device_name) {
    return kInvalidLocalHandle;
  }
  return function_data->local_handle();
}
351
352string ProcessFunctionLibraryRuntime::GetDeviceName(
353 FunctionLibraryRuntime::Handle handle) const {
354 tf_shared_lock l(mu_);
355 auto iter = function_data_.find(handle);
356 CHECK(iter != function_data_.end());
357 FunctionData* function_data = iter->second.get();
358 return function_data->target_device();
359}
360
361ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
362ProcessFunctionLibraryRuntime::IsMultiDevice(
363 FunctionLibraryRuntime::Handle handle) const {
364 tf_shared_lock l(mu_);
365 const auto& it = mdevice_data_.find(handle);
366 if (it != mdevice_data_.end()) {
367 return it->second.get();
368 }
369 return nullptr;
370}
371
372namespace {
373// Sets `group` to the first colocation group specified in `node`. If no
374// group is specified, does not touch `group`.
375void GetColocationGroup(const Node* node, string* group) {
376 // We hoist the conversion from C-style string literal to string here,
377 // so that we can avoid the many repeated calls to strlen().
378 static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
379 const AttrValue* attr_value =
380 node->attrs().Find(kColocationAttrNameStringPiece);
381 if (attr_value != nullptr && attr_value->has_list() &&
382 attr_value->list().s_size() > 0) {
383 *group = attr_value->list().s(0);
384 }
385}
386
387const string* AssignedOrRequestedDeviceName(const Node& node) {
388 if (node.has_assigned_device_name()) {
389 return &node.assigned_device_name();
390 }
391 return &node.requested_device();
392}
393
394Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
395 input_resource_dtypes_and_shapes,
396 const std::vector<Node*>& arg_nodes) {
397 for (Node* n : arg_nodes) {
398 int index;
399 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
400 DataType dtype;
401 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
402 if (dtype == DT_RESOURCE) {
403 auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
404 if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
405 AttrValue dtype_attr_value;
406 dtype_attr_value.mutable_list()->add_type(
407 dtype_and_shape_iter->second.dtype);
408 n->AddAttr("_handle_dtypes", dtype_attr_value);
409 TensorShapeProto shape_proto;
410 dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
411 AttrValue shape_attr_value;
412 *shape_attr_value.mutable_list()->add_shape() = shape_proto;
413 n->AddAttr("_handle_shapes", shape_attr_value);
414 }
415 }
416 }
417 return OkStatus();
418}
419
420// Returns the local tensors referred by `args`.
421std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
422 std::vector<Tensor> tensors;
423 for (const auto& arg : args) {
424 if (arg.index() == 0) {
425 tensors.push_back(absl::get<Tensor>(arg));
426 }
427 }
428 return tensors;
429}
430
// Update the done callback to push Tensors in `tensors` into `rets`.
// Takes ownership of the heap-allocated `tensors` vector; it is deleted
// once the returned callback runs, regardless of status.
FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
    std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
    FunctionLibraryRuntime::DoneCallback done) {
  return [rets, tensors, done = std::move(done)](const Status& s) {
    // Only propagate outputs on success; on failure `rets` is left untouched.
    if (s.ok()) {
      for (const auto& t : *tensors) {
        rets->push_back(t);
      }
    }
    // Deleted unconditionally so the error path does not leak `tensors`.
    delete tensors;
    done(s);
  };
}
445
446// Push Tensors in `function_rets` into `tensors`.
447Status FunctionRetsToTensors(const std::vector<FunctionRet>* function_rets,
448 std::vector<Tensor>* tensors) {
449 for (const auto& ret : *function_rets) {
450 if (ret.index() != 0) {
451 return errors::Internal(
452 "Expect a Tensor as a function output but got a TensorShape.");
453 }
454 tensors->push_back(absl::get<Tensor>(ret));
455 }
456 return OkStatus();
457}
458
459} // anonymous namespace
460
461Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
462 const std::vector<string>& input_devices,
463 const std::vector<string>& output_devices, const DeviceSet& device_set,
464 const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
465 const FunctionLibraryDefinition* lib_def, Device* default_device) {
466 // If output_devices are not specified, we want to set the output device
467 // based on the device of the output producing node. The output producing
468 // node can be an arg node because functions can simply return their
469 // arguments. To make sure that the output producing nodes have assigned
470 // devices, we assign them to arguments first.
471 for (Node* node : arg_nodes) {
472 const AttrValue* attr_value;
473 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
474 int64_t index = attr_value->i();
475 node->set_assigned_device_name(input_devices[index]);
476 }
477
478 for (Node* node : ret_nodes) {
479 if (output_devices.empty()) {
480 DataType dtype;
481 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
482
483 VLOG(3) << "Trying to determine device for node " << node->name()
484 << "[T=" << DataTypeString(dtype) << "]";
485
486 // If output_devices are empty, the node producing retval
487 // must have explicitly assigned device or a colocation constraint
488 // to a node with explicitly assigned device.
489 for (const auto& it : node->in_edges()) {
490 if (it->IsControlEdge()) continue;
491
492 Node* src_node = it->src();
493 const string* src_device = AssignedOrRequestedDeviceName(*src_node);
494 string colocation_group = "";
495 GetColocationGroup(src_node, &colocation_group);
496 VLOG(3) << "Considering src: " << src_node->name()
497 << " src_device: " << *src_device
498 << " colo group: " << colocation_group;
499 while (src_device->empty() && colocation_group.empty() &&
500 src_node->IsIdentity()) {
501 // Only follows the real data input of Identity, not control edges.
502 Node* input_node;
503 TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
504 src_node = input_node;
505
506 src_device = AssignedOrRequestedDeviceName(*src_node);
507 GetColocationGroup(src_node, &colocation_group);
508 VLOG(3) << "Considering src: " << src_node->name()
509 << " src_device: " << *src_device
510 << " colo group: " << colocation_group;
511 }
512
513 // If resource is produced by a function call node, we can't trust
514 // source node device assignment, because multi-device functions can
515 // return resource placed on multiple devices. In such case we leave
516 // retval device assignment empty, and rely on placer to infer correct
517 // assignment based on actual output device.
518 const bool can_use_src_node_device =
519 !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));
520
521 if (!colocation_group.empty()) {
522 AttrValue::ListValue colo_attr;
523 colo_attr.add_s(colocation_group);
524 std::vector<string> colo_slice = {colocation_group};
525 node->AddAttr(kColocationAttrName, colo_slice);
526 } else if (!src_device->empty() && can_use_src_node_device) {
527 // Do not copy device from src node for variants, unless it is a no-op
528 // forward from input to output. This gets handled in
529 // colocation_graph.cc which has special logic for correctly placing
530 // _Retvals for various variant types.
531 if (dtype == DT_VARIANT && !src_node->IsArg()) {
532 continue;
533 }
534 // src_device can be a partially specified device. Find the
535 // matching device in the device_set.
536 DeviceNameUtils::ParsedName parsed;
537 if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
538 return errors::InvalidArgument(
539 "Failed to parse explicit device specification ", *src_device);
540 }
541 std::vector<Device*> matching_devices;
542 device_set.FindMatchingDevices(parsed, &matching_devices);
543 if (matching_devices.empty()) {
544 if (default_device != nullptr) {
545 matching_devices.push_back(default_device);
546 } else {
547 return errors::InvalidArgument(
548 "Unable to find any devices for spec ", *src_device);
549 }
550 } else if (matching_devices.size() != 1) {
551 bool on_same_task = true;
552 for (int i = 1; i < matching_devices.size(); ++i) {
553 if (!DeviceNameUtils::IsSameAddressSpace(
554 matching_devices.at(0)->parsed_name(),
555 matching_devices.at(i)->parsed_name())) {
556 on_same_task = false;
557 break;
558 }
559 }
560 // If the src node of an output is assigned to a address space (e.g.
561 // py_func), rely on placer to assign a device to the output.
562 if (on_same_task) {
563 continue;
564 }
565 // Compare with default_device if it has a narrower scope matching
566 // requested device.
567 if (default_device != nullptr) {
568 int colocated_on_default_device = 0;
569 for (int i = 0; i < matching_devices.size(); ++i) {
570 if (DeviceNameUtils::IsSameAddressSpace(
571 default_device->parsed_name(),
572 matching_devices.at(i)->parsed_name())) {
573 colocated_on_default_device++;
574 }
575 }
576 // Continue to raise error if multiple colocated devices are
577 // found.
578 if (colocated_on_default_device == 1) {
579 continue;
580 }
581 }
582 // Convert a vector of devices to a string.
583 // Using absl::StrJoin did not work in Android builds.
584 string devices = "[";
585 for (Device* device : matching_devices) {
586 devices.append(device->name());
587 devices.append(", ");
588 }
589 if (devices.size() > 2) {
590 devices.resize(devices.size() - 2);
591 }
592 devices.append("]");
593
594 return errors::InvalidArgument(
595 *src_device,
596 "When FunctionLibraryRuntime::Options.output_devices are "
597 "not specified for a multi-device function, the device "
598 "specification on the output node must match exactly one "
599 "device. Matched devices are ",
600 devices);
601 }
602 VLOG(3) << "Setting output device to " << matching_devices[0]->name()
603 << " for node " << SummarizeNode(*node);
604 node->set_assigned_device_name(matching_devices[0]->name());
605 } else if (!src_device->empty() && !can_use_src_node_device) {
606 VLOG(3) << "Did not set device for a resource output node "
607 << SummarizeNode(*node);
608 }
609 }
610 } else {
611 const AttrValue* attr_value;
612 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
613 int64_t index = attr_value->i();
614 // output_devices size is checked in InstantiateMultiDevice
615 DCHECK_GT(output_devices.size(), index);
616 VLOG(3) << "Setting output device to " << output_devices[index]
617 << " for return at index " << index;
618 node->set_assigned_device_name(output_devices[index]);
619 }
620 }
621 return OkStatus();
622}
623
624namespace {
625
626Status ValidateNoListArguments(
627 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
628 const string& function_name) {
629 for (const OpDef::ArgDef& arg : args) {
630 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
631 return errors::InvalidArgument(
632 "Function ", function_name, " has an ", arg_type, " named \"",
633 arg.name(),
634 "\" that is a list of tensors."
635 " Multi-device functions support only single-tensor inputs "
636 " and outputs");
637 }
638 }
639 return OkStatus();
640}
641
642Status ValidateMultiDeviceOptions(
643 const FunctionDef& fdef,
644 const FunctionLibraryRuntime::InstantiateOptions& options) {
645 const OpDef& signature = fdef.signature();
646 // Multi-device functions currently do not support list inputs or outputs.
647 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
648 signature.name()));
649 TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
650 signature.name()));
651 if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
652 fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
653 return errors::Unimplemented(
654 "Function '", signature.name(), "' has `",
655 FunctionLibraryDefinition::kIntsOnDeviceAttr,
656 "` attribute set. This attribute is not currently supported by "
657 "multi-device functions.");
658 }
659 if (options.input_devices.size() != signature.input_arg_size()) {
660 return errors::InvalidArgument(
661 "InstantiateOptions.input_devices must have the same length "
662 "as the number of arguments: input_devices length = ",
663 options.input_devices.size(),
664 " number of arguments = ", signature.input_arg_size());
665 }
666 if (!options.output_devices.empty() &&
667 options.output_devices.size() != signature.output_arg_size()) {
668 return errors::InvalidArgument(
669 "InstantiateOptions.output_devices must either be empty or have the "
670 "same length as the number of arguments: output_devices length = ",
671 options.output_devices.size(),
672 " number of arguments = ", signature.output_arg_size());
673 }
674 return OkStatus();
675}
676
677} // anonymous namespace
678
679ProcessFunctionLibraryRuntime::AsyncAttributes::Summary
680ProcessFunctionLibraryRuntime::AsyncAttributes::Summarize(const Graph* graph) {
681 bool has_send_op = false;
682 bool has_recv_op = false;
683 bool has_unsafe_op = false;
684 for (const Node* node : graph->nodes()) {
685 if (node->IsSend() || node->IsHostSend()) {
686 has_send_op = true;
687 }
688 if (node->IsRecv() || node->IsHostRecv()) {
689 has_recv_op = true;
690 }
691 if (!ValidateOpIsSafeForSyncExecution(*node,
692 allow_control_flow_sync_execution())
693 .ok()) {
694 has_unsafe_op = true;
695 }
696 }
697 // (1) Anything completely unsupported?
698 if (has_unsafe_op) {
699 metrics::IncrementTestCounter("subgraph_async_summary", "unsafe_op");
700 return AsyncAttributes::kAsyncRequired;
701 }
702 // (2) That only leaves send/recv. If neither, then it's safe.
703 if (!has_send_op && !has_recv_op) {
704 metrics::IncrementTestCounter("subgraph_async_summary", "safe_for_sync");
705 return AsyncAttributes::kSafeForSync;
706 }
707 // (3) If each subgraph has only send or only recv, then it's possible to
708 // order them to run sequentially without deadlock.
709 if (has_send_op && !has_recv_op) {
710 metrics::IncrementTestCounter("subgraph_async_summary", "send_only");
711 return AsyncAttributes::kSendOnly;
712 }
713 if (has_recv_op && !has_send_op) {
714 metrics::IncrementTestCounter("subgraph_async_summary", "recv_only");
715 return AsyncAttributes::kRecvOnly;
716 }
717 // Otherwise, assume it's unsupported.
718 metrics::IncrementTestCounter("subgraph_async_summary", "other");
719 return AsyncAttributes::kAsyncRequired;
720}
721
// Builds a Graph from `fdef` and extracts its arg/ret/control-ret metadata.
// On success, `graph` owns the FunctionBody's graph; the remaining output
// vectors are appended to (not cleared). Returns Internal if the FunctionDef
// could not be converted to a FunctionBody.
Status GetGraphAndArgRets(
    const string& function_name, AttrSlice attrs, const FunctionDef* fdef,
    const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph,
    std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes,
    std::vector<string>* ret_node_names, DataTypeVector* ret_types,
    std::vector<string>* control_ret_node_names) {
  std::unique_ptr<FunctionBody> fbody;
  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody));
  if (!fbody) {
    LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
    return errors::Internal("Failed to construct FunctionBody for ",
                            function_name);
  }
  // Transfer graph ownership from fbody to the caller; fbody->graph is nulled
  // below so fbody's destructor does not delete it.
  *graph = std::unique_ptr<Graph>(fbody->graph);
  arg_nodes->reserve(fbody->arg_nodes.size());
  std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(),
            std::back_inserter(*arg_nodes));
  ret_nodes->reserve(fbody->ret_nodes.size());
  std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(),
            std::back_inserter(*ret_nodes));
  fbody->graph = nullptr;
  // fbody's node/type vectors remain valid here; only the graph pointer was
  // transferred above.
  ret_node_names->reserve(fbody->ret_nodes.size());
  for (const Node* node : fbody->ret_nodes) {
    ret_node_names->push_back(node->name());
  }
  for (const auto& ret_type : fbody->ret_types) {
    ret_types->push_back(ret_type);
  }
  control_ret_node_names->reserve(fbody->control_ret_nodes.size());
  for (const Node* node : fbody->control_ret_nodes) {
    control_ret_node_names->push_back(node->name());
  }
  return OkStatus();
}
757
758StatusOr<ProcessFunctionLibraryRuntime::OptimizedFunctionGraphInfo>
759ProcessFunctionLibraryRuntime::OptimizeFunctionGraph(
760 const string& function_name, AttrSlice attrs,
761 const FunctionLibraryRuntime::InstantiateOptions& options,
762 const std::shared_ptr<DeviceSet>& dev_set) {
763 const FunctionLibraryDefinition* lib_def =
764 options.lib_def == nullptr ? lib_def_ : options.lib_def;
765
766 const FunctionDef* fdef = lib_def->Find(function_name);
767 if (fdef == nullptr) {
768 return errors::InvalidArgument("Failed to find function \"", function_name,
769 "\" in function library: ", lib_def);
770 }
771
772 TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));
773
774 std::unique_ptr<Graph> graph;
775 std::vector<Node*> arg_nodes, ret_nodes;
776 std::vector<string> ret_node_names;
777 DataTypeVector ret_types;
778 std::vector<string> control_ret_node_names;
779
780 TF_RETURN_IF_ERROR(GetGraphAndArgRets(
781 function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
782 &ret_node_names, &ret_types, &control_ret_node_names));
783
784 GraphDef graph_def;
785 graph->ToGraphDef(&graph_def);
786 FunctionLibraryDefinition reachable_lib_def =
787 lib_def->ReachableDefinitions(graph_def);
788 *graph_def.mutable_library() = reachable_lib_def.ToProto();
789 if (options.graph_collector != nullptr) {
790 options.graph_collector->CollectRawGraph(graph_def);
791 }
792
793 Device* default_device = nullptr;
794 if (options.default_device_to_target && !options.target.empty()) {
795 // Make the `target` device the default device if nothing else is hard
796 // coded. This allows the same function definition to be specialized to
797 // different devices depending on the `PartitionedCallOp` device.
798 FunctionLibraryRuntime* flr = GetFLR(options.target);
799 if (flr == nullptr) {
800 return errors::InvalidArgument(
801 "Cannot instantiate multi-device function with target device ",
802 options.target);
803 }
804 default_device = flr->device();
805 }
806
807 // Mark and assign device for each node in the graph to be compiled by
808 // specified device.
809 if (!options.xla_compile_device_type.empty()) {
810 for (Node* node : graph->op_nodes()) {
811 node->AddAttr("_xla_compile_device_type",
812 options.xla_compile_device_type);
813 if (default_device) {
814 node->set_assigned_device_name(default_device->name());
815 }
816 }
817 }
818
819 TF_RETURN_IF_ERROR(
820 SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
821 TF_RETURN_IF_ERROR(PinArgsAndRets(
822 options.input_devices, options.output_devices, *dev_set, arg_nodes,
823 ret_nodes, lib_def_,
824 options.config_proto.allow_soft_placement() ? default_device : nullptr));
825
826 // The runtime shouldn't depend on duplication between the function library
827 // owned by the graph and the one owned by the runtime. To ensure this, for
828 // now we ensure that the graph function library is empty and the runtime
829 // library receives the query from LookUps on the graph function library.
830 graph->mutable_flib_def()->set_default_registry(&reachable_lib_def);
831 graph->mutable_flib_def()->Clear();
832
833 // Do not run function/graph optimization passes for component functions,
834 // since they have already processed the main function.
835 const bool should_run_optimization_passes = !options.is_component_function;
836 if (!should_run_optimization_passes) {
837 VLOG(1) << "Skipping function/graph optimization passes when instantiating "
838 "component function "
839 << function_name;
840 }
841
842 // Mapping from a function body node name to the control output name.
843 std::unordered_map<string, string> node_name_to_control_ret;
844
845 bool control_rets_updated = false;
846 if (should_run_optimization_passes) {
847 TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
848 *dev_set, options.config_proto, &graph, &reachable_lib_def,
849 &control_ret_node_names, &control_rets_updated));
850 }
851
852 if (control_rets_updated) {
853 // Function graph pass may have resulted in different nodes/node names for
854 // control rets.
855 for (const auto& control_ret : control_ret_node_names) {
856 node_name_to_control_ret.emplace(control_ret, control_ret);
857 }
858 } else {
859 for (const auto& control_ret : fdef->control_ret()) {
860 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
861 }
862 }
863
864 GraphOptimizationPassOptions optimization_options;
865 // TODO(iga): Thread other relevant options from SessionOptions.
866 SessionOptions session_options;
867 session_options.env = env_;
868 session_options.config = options.config_proto;
869 optimization_options.session_options = &session_options;
870 optimization_options.graph = &graph;
871 optimization_options.flib_def = &reachable_lib_def;
872 optimization_options.device_set = dev_set.get();
873 optimization_options.is_function_graph = true;
874 std::vector<CompositeDevice*> composite_devices;
875 {
876 tf_shared_lock l(mu_);
877 for (auto* d : composite_devices_) composite_devices.push_back(d);
878 }
879 optimization_options.composite_devices = &composite_devices;
880 optimization_options.default_function_device = default_device;
881 optimization_options.function_def = fdef;
882 optimization_options.shape_inference_on_tfe_dialect_import =
883 options.shape_inference_on_tfe_dialect_import;
884
885 DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
886 if (should_run_optimization_passes) {
887 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
888 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
889 }
890
891 // TODO(b/124993244): Smartly merge options in nested defuns, and raise
892 // exceptions/warnings in case where nested function call options are ignored.
893 DumpGraph("Before calling Placer", graph.get());
894 Placer placer(graph.get(), function_name, optimization_options.flib_def,
895 dev_set.get(), default_device,
896 options.config_proto.allow_soft_placement(),
897 options.config_proto.log_device_placement());
898 TF_RETURN_IF_ERROR(placer.Run());
899
900 DumpGraph("Before running POST_PLACEMENT passes", graph.get());
901 if (should_run_optimization_passes) {
902 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
903 OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
904 }
905
906 Device* cpu_device;
907 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));
908
909 if (options.optimize_graph_fn) {
910 DumpGraph("Before running graph optimization fn", graph.get());
911 Status status = options.optimize_graph_fn(
912 std::move(ret_node_names), std::move(control_ret_node_names),
913 &reachable_lib_def, *dev_set, cpu_device, &graph);
914 if (!status.ok()) {
915 LOG(WARNING) << "Ignoring multi-device function optimization failure: "
916 << status.ToString();
917 }
918 DumpGraph("After optimization", graph.get());
919 }
920
921 DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
922 if (should_run_optimization_passes) {
923 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
924 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
925 }
926
927 graph->mutable_flib_def()->set_default_registry(nullptr);
928 graph->mutable_flib_def()->Clear();
929 return OptimizedFunctionGraphInfo{
930 std::move(graph), std::move(reachable_lib_def), node_name_to_control_ret,
931 std::move(ret_types), ret_nodes.size()};
932}
933
// Instantiates a multi-device function: optimizes and partitions the function
// graph into one component function per target device, instantiates each
// component (locally via its FLR, or remotely via InstantiateRemote), and
// returns a single composite handle through `*handle`. Instantiating the same
// canonicalized (name, attrs, options) key again reuses the existing handle
// and increments its instantiation counter.
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      // Cache hit: hand back the existing handle and bump its refcount.
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return OkStatus();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    int index = 0;
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << "    [input " << index++ << "] " << device;
    }
    index = 0;
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << "    [output " << index++ << "] " << device;
    }
  }

  // Snapshot the device set once; optimization and partitioning below all use
  // this snapshot. The wall time around graph optimization is reported to
  // metrics at the end.
  const std::shared_ptr<DeviceSet> dev_set = device_set();
  const uint64 optimization_start_time_usecs = Env::Default()->NowMicros();
  TF_ASSIGN_OR_RETURN(
      auto optimized_graph_info,
      OptimizeFunctionGraph(function_name, attrs, options, dev_set));

  // Point the graph's function library at the optimized library so that
  // lookups during the remaining transformations resolve against it.
  auto& graph = optimized_graph_info.graph;
  graph->mutable_flib_def()->set_default_registry(
      &(optimized_graph_info.lib_def));

  // Expand the nodes assigned to a CompositeDevice before graph partition to
  // avoid generating a subgraph on a virtual device for execution.
  // This transformation should happen as late as possible, in order to run as
  // more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
      options.composite_devices, graph.get()));

  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;
  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  VLOG(4) << "Main function graph to be partitioned:";
  VLOG(4) << DebugString(graph->ToGraphDefDebug());

  // Split the optimized graph into one subgraph per target device; `graph` is
  // consumed here.
  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));

  for (const auto& pair : subgraphs) {
    DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
                              pair.first, ")"),
              pair.second.get());
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = &(optimized_graph_info.lib_def);
  optimization_options.is_function_graph = true;
  optimization_options.graph = nullptr;
  optimization_options.device_set = nullptr;
  optimization_options.partition_graphs = &subgraphs;
  // Normally POST_PARTITIONING passes are run by distributed workers.
  // Distributed workers are currently not supported in this code path, so we
  // run the passes here.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  }

  for (const auto& pair : subgraphs) {
    const auto* optimized_subgraph = pair.second.get();
    DumpGraph(
        strings::StrCat("After all optimization passes (", pair.first, ")"),
        optimized_subgraph);
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("pflr_after_all_optimization_passes_",
                          reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
                          pair.first),
          optimized_subgraph->ToGraphDefDebug());
    }
  }
  const uint64 optimization_end_time_usecs = Env::Default()->NowMicros();
  metrics::UpdateFunctionGraphOptimizationTime(optimization_end_time_usecs -
                                               optimization_start_time_usecs);

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  const auto& node_name_to_control_ret =
      optimized_graph_info.node_name_to_control_ret;
  // We must preserve control returns in each of the function components,
  // otherwise after function inlining we might prune side-effectful nodes.
  const auto control_ret =
      [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
    const auto it = node_name_to_control_ret.find(n->name());
    return it != node_name_to_control_ret.end()
               ? absl::make_optional<string>(it->second)
               : absl::nullopt;
  };

  // Composite bookkeeping for all component functions; takes ownership of the
  // optimized function library and return types.
  auto data = std::make_unique<MultiDeviceFunctionData>(
      function_name, function_key, optimized_graph_info.num_return_nodes,
      std::move(optimized_graph_info.lib_def),
      std::move(optimized_graph_info.ret_types));

  int i = 0;
  // Generate a random function_name to avoid one function reuse the partition
  // function instantiated by another function.
  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
  FunctionNameGenerator name_generator(
      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
  auto num_subgraphs = subgraphs.size();
  gtl::InlinedVector<Status, 4> instantiate_status(num_subgraphs);
  BlockingCounter counter(static_cast<int>(num_subgraphs));
  auto runner = [this, num_subgraphs](std::function<void()> fn) {
    // NOTE: Only use thread pool to instantiate sub-function when there are
    // more than 8 sub-functions. We want to avoid cost of switching thread when
    // there are only a few sub-functions.
    if (default_thread_pool_ != nullptr && num_subgraphs > 8) {
      default_thread_pool_->Schedule(fn);
    } else {
      fn();
    }
  };

  // Before instantiating component functions, determine synchronous execution.
  data->enable_sync_execution = false;
  if (options.allow_small_function_optimizations) {
    data->enable_sync_execution = true;
    for (const auto& pair : subgraphs) {
      ComponentFunctionData* comp_data = &data->glue_[pair.first];
      const Graph* subgraph = pair.second.get();
      comp_data->async_attributes =
          AsyncAttributes(subgraph, options.allow_control_flow_sync_execution);
      if (comp_data->async_attributes.summary() ==
          AsyncAttributes::kAsyncRequired) {
        data->enable_sync_execution = false;
      }
    }
  }

  // Instantiate each component function (subgraph).
  for (const auto& pair : subgraphs) {
    Status* status = &instantiate_status[i];
    string unique_name = name_generator.GetName();
    ComponentFunctionData* comp_data = &data->glue_[pair.first];
    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
            &control_ret, &options, status, &counter, &data] {
      const string& target = pair.first;

      const string& device_type =
          dev_set->FindDeviceByName(target)->device_type();
      Graph* subgraph = pair.second.get();

      bool ints_on_device =
          (device_type == "TPU" || device_type == "XLA_CPU" ||
           device_type == "XLA_GPU" || options.int_args_and_retvals_on_device);
      status->Update(UpdateArgAndRetvalMetadata(
          subgraph, &comp_data->arg_indices, &comp_data->ret_indices,
          &comp_data->arg_alloc_attrs, &comp_data->ret_alloc_attrs,
          ints_on_device));
      // On any failure below we must still decrement the counter so the main
      // thread's Wait() does not block forever.
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionDef shard;
      status->Update(
          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      status->Update(data_lib_def->AddFunctionDef(shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionLibraryRuntime::InstantiateOptions opts;
      opts.executor_type = options.executor_type;
      opts.target = target;
      opts.lib_def = data_lib_def;
      opts.create_kernels_eagerly = options.create_kernels_eagerly;
      opts.state_handle = options.state_handle;
      opts.allow_small_function_optimizations = data->enable_sync_execution;
      opts.allow_control_flow_sync_execution =
          options.allow_control_flow_sync_execution;
      AttrValue ints_on_device_attr;
      ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
      shard.mutable_attr()->insert(
          {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
      auto attrs = AttrSlice(&shard.attr());
      VLOG(1) << "Start instantiating component function " << unique_name
              << " on device " << target;
      VLOG(4) << DebugString(shard);

      // Completion callback shared by the sync (local) and async (remote)
      // paths: records cross-process state, publishes the component handle on
      // success, and signals the blocking counter exactly once.
      auto* component_handle = new FunctionLibraryRuntime::Handle;
      auto done = [this, status, unique_name, comp_data, component_handle,
                   &data, &counter](const Status& s) {
        status->Update(s);

        VLOG(1) << "Finished instantiating component function " << unique_name
                << " with handle " << *component_handle << " status: " << s;
        if (status->ok()) {
          {
            mutex_lock l(mu_);
            if (function_data_[*component_handle]->is_cross_process()) {
              data->is_cross_process_ = true;
            }
          }
          comp_data->handle = *component_handle;
        }
        delete component_handle;
        counter.DecrementCount();
      };

      FunctionLibraryRuntime* flr = GetFLR(opts.target);
      if (flr != nullptr) {
        // Initialize local function synchronously.
        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
        done(s);
      } else {
        opts.ret_indices = comp_data->ret_indices;
        // Initialize remote function asynchronously.
        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
      }
    });
    i += 1;
  }
  // Block until every component has finished instantiating, then merge all
  // per-component statuses into a single summary status.
  counter.Wait();
  StatusGroup group;
  for (auto& status : instantiate_status) {
    group.Update(status);
  }
  TF_RETURN_IF_ERROR(group.as_summary_status());

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return OkStatus();
}
1199
1200Status ProcessFunctionLibraryRuntime::GetOutputDevices(
1201 FunctionLibraryRuntime::Handle handle,
1202 std::vector<Device*>* output_devices) const {
1203 MultiDeviceFunctionData* data = IsMultiDevice(handle);
1204 if (data == nullptr) {
1205 return errors::InvalidArgument(
1206 "Failed for find multi-device function handle ", handle);
1207 }
1208
1209 for (const auto& pair : data->glue_) {
1210 const ComponentFunctionData& comp_data = pair.second;
1211 DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size());
1212 if (comp_data.ret_indices.empty()) {
1213 continue;
1214 }
1215
1216 const string& target = pair.first;
1217 FunctionLibraryRuntime* target_flr = GetFLR(target);
1218 Device* target_device = nullptr;
1219 Device* host = nullptr;
1220 if (target_flr == nullptr) {
1221 if (!data->has_remote_outputs) {
1222 data->has_remote_outputs = true;
1223 }
1224 target_device = device_set()->FindDeviceByName(target);
1225 string remote_host;
1226 TF_RETURN_IF_ERROR(
1227 DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
1228 host = device_set()->FindDeviceByName(remote_host);
1229 } else {
1230 target_device = target_flr->device();
1231 }
1232 output_devices->resize(data->num_outputs_);
1233 for (int j = 0; j < comp_data.ret_indices.size(); ++j) {
1234 int ret_index = comp_data.ret_indices[j];
1235 if (data->ret_types_[ret_index] == DT_RESOURCE) {
1236 (*output_devices)[ret_index] = target_device;
1237 } else {
1238 (*output_devices)[ret_index] =
1239 comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
1240 }
1241 }
1242 }
1243
1244 return OkStatus();
1245}
1246
1247Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice(
1248 const FunctionLibraryRuntime::Options& opts,
1249 FunctionLibraryRuntime::Handle handle,
1250 const MultiDeviceFunctionData** data) const {
1251 if (opts.create_rendezvous) {
1252 // FLR->Run() is the default entry point. It checks for cancellation,
1253 // creates rendezvous, etc.
1254 // Letting create_rendezvous through will do the wrong thing - each
1255 // component function will get a separate rendezvous created by its FLR.
1256 return errors::Internal(
1257 "Cannot call ProcessFunctionLibraryRuntime::Run with "
1258 "create_rendezvous=true. Please run the function "
1259 "using FunctionLibraryRuntime::Run");
1260 }
1261
1262 *data = IsMultiDevice(handle);
1263 if (*data == nullptr) {
1264 return errors::NotFound("Multi-device function handle ", handle,
1265 "not found. Was the function instantiated?");
1266 }
1267
1268 // Check whether we have the right rendezvous.
1269 if (opts.rendezvous && (*data)->is_cross_process_ &&
1270 !opts.rendezvous->is_cross_process()) {
1271 return errors::InvalidArgument(
1272 "Running a cross process function ", (*data)->function_name_,
1273 " without an appropriate cross process Rendezvous.");
1274 }
1275
1276 return OkStatus();
1277}
1278
1279std::vector<string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs(
1280 const MultiDeviceFunctionData* data) const {
1281 std::vector<string> subgraph_keys;
1282 subgraph_keys.reserve(data->glue_.size());
1283 for (const auto& pair : data->glue_) {
1284 subgraph_keys.push_back(pair.first);
1285 }
1286 auto send_first_ordering = [&](const string& a, const string& b) {
1287 auto a_summary = data->glue_.at(a).async_attributes.summary();
1288 auto b_summary = data->glue_.at(b).async_attributes.summary();
1289 if (a_summary == b_summary) {
1290 return false;
1291 }
1292 if (a_summary == AsyncAttributes::kSendOnly) {
1293 return true;
1294 }
1295 return false;
1296 };
1297 std::sort(subgraph_keys.begin(), subgraph_keys.end(), send_first_ordering);
1298 return subgraph_keys;
1299}
1300
// Runs every component function of the multi-device function `outer_handle`
// sequentially on the calling thread. `get_component_args` supplies each
// component's input arguments; outputs are scattered into `*rets` using the
// component's recorded return indices. Components are ordered send-first (see
// GetOrderedSubgraphs) to avoid deadlock between dependent subgraphs.
Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
    std::function<Status(const ComponentFunctionData& comp_data,
                         InternalArgs* args)>
        get_component_args) const {
  const MultiDeviceFunctionData* data;
  Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
  if (!prepare_status.ok()) {
    return prepare_status;
  }

  // One mutable copy of the caller's options is reused across iterations;
  // per-component fields are overwritten before each run.
  FunctionLibraryRuntime::Options opts_copy = opts;

  // Sort the subgraphs topologically before execution to avoid deadlock:
  //
  // Because subgraphs will not execute in parallel here, dependencies between
  // subgraphs cannot be resolved automatically. In contrast, with multi-
  // threaded execution, we launch all subgraphs at once, asynchronously, and
  // allow any to block mid-execution while its dependencies are resolved.
  //
  // In this synchronous execution path, currently supported ops with inter-
  // subgraph dependencies are send and receive. As `_Send` and `_HostSend`
  // are non-blocking, we run subgraphs with those first, and those with
  // the blocking '_Recv' and '_HostRecv' ops will have their dependencies
  // resolved before execution.
  //
  // We assume that the partitioning has a valid deadlock-free ordering and the
  // safety of running synchronously has already been confirmed by this point.
  std::vector<string> subgraph_keys = GetOrderedSubgraphs(data);

  for (const string& target : subgraph_keys) {
    const ComponentFunctionData& comp_data = data->glue_.at(target);
    FunctionLibraryRuntime::Handle comp_handle = comp_data.handle;

    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;

    InternalArgs comp_args;
    Status args_status = get_component_args(comp_data, &comp_args);
    if (!args_status.ok()) {
      VLOG(2) << "Failed to get component function arguments: " << args_status;
      return args_status;
    }
    rets->resize(data->num_outputs_);

    VLOG(1) << "Running component function on device " << target << " from "
            << data->function_name_ << " with handle " << comp_handle;
    FunctionLibraryRuntime* flr = GetFLR(target);
    if (flr != nullptr) {
      opts_copy.remote_execution = false;
      // When target device has private thread pool, use the target device
      // runner
      thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
      opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();
      VLOG(4) << "    with " << opts_copy.DebugString();

      std::vector<Tensor> comp_tensor_rets;
      Status run_status =
          flr->RunSync(opts_copy, comp_handle, GetLocalArgs(comp_args.args),
                       &comp_tensor_rets);
      if (!run_status.ok()) {
        VLOG(2) << "Component function execution failed: " << run_status;
        const string function_and_msg = strings::StrCat(
            errors::FormatFunctionForError(data->function_name_), " ",
            run_status.error_message());
        // Abort the rendezvous so peers blocked on sends/recvs from this
        // component do not hang.
        if (opts.rendezvous != nullptr) opts.rendezvous->StartAbort(run_status);
        return errors::CreateWithUpdatedMessage(run_status, function_and_msg);
      } else {
        VLOG(2) << "Component function execution succeeded.";
        // Scatter this component's returns into the caller's output slots.
        for (int i = 0; i < comp_tensor_rets.size(); ++i) {
          (*rets)[comp_data.ret_indices[i]] = comp_tensor_rets[i];
        }
      }
    } else {
      // Fall back to DistributedFunctionLibraryRuntime for remote execution.
      opts_copy.remote_execution = true;
      VLOG(4) << "    with " << opts_copy.DebugString();

      std::vector<std::unique_ptr<CleanUpItem>> cleanup_items;
      Notification n;
      Status s;
      std::vector<FunctionRet> comp_rets;
      RunInternal(opts_copy, comp_handle, comp_args.args, &comp_rets,
                  &cleanup_items, [&n, &s](const Status& status) {
                    s.Update(status);
                    n.Notify();
                  });
      n.WaitForNotification();
      // NOTE(review): this returns after the first remote component even when
      // `s` is OK, so any later entries in `subgraph_keys` are never run (and
      // `comp_rets` is discarded). Confirm this early return is intentional.
      return s;
    }
  }
  return OkStatus();
}
1395
// Launches every component function of the multi-device function
// `outer_handle` asynchronously, scattering each component's outputs into
// `*rets` by return index. `done` is invoked exactly once, after all
// component callbacks have completed, with the merged status. A failure in
// any component cancels the others via the (possibly locally created)
// cancellation manager.
void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done,
    std::function<Status(const ComponentFunctionData& comp_data,
                         InternalArgs* args)>
        get_component_args) const {
  const MultiDeviceFunctionData* data;
  Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
  if (!prepare_status.ok()) {
    done(prepare_status);
    return;
  }

  // A locally created cancellation manager, used only when the caller does not
  // provide one in argument.
  std::shared_ptr<CancellationManager> local_cm;
  CancellationManager* cm = opts.cancellation_manager;
  if (cm == nullptr) {
    // The shared_ptr is captured by every component callback below, keeping
    // the locally created manager alive until the last callback runs.
    local_cm = std::make_shared<CancellationManager>();
    cm = local_cm.get();
  }

  // `done` fires when all references are dropped: one per component (added
  // here) plus the creation reference (released at the end of this function).
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (int i = 0; i < data->glue_.size(); ++i) {
    refcounted_done->Ref();
  }

  FunctionLibraryRuntime::Options opts_copy = opts;
  for (const auto& pair : data->glue_) {
    const string& target = pair.first;
    const ComponentFunctionData& comp_data = pair.second;
    FunctionLibraryRuntime::Handle comp_handle = pair.second.handle;

    // Per-component option overrides; the rest of opts_copy is shared.
    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
    opts_copy.cancellation_manager = cm;

    InternalArgs comp_args;
    Status s = get_component_args(comp_data, &comp_args);
    if (!s.ok()) {
      // Record the failure, drop this component's reference, and cancel the
      // components that were already launched.
      VLOG(2) << "Failed to get component function arguments: " << s;
      refcounted_done->UpdateStatus(s);
      refcounted_done->Unref();
      cm->StartCancel();
      continue;
    }
    std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
    rets->resize(data->num_outputs_);

    // Note: `comp_data` and `local_cm` are captured by value so the callback
    // remains valid regardless of when it fires; `comp_rets` is heap-allocated
    // and deleted inside the callback.
    auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
                                  cm, local_cm, data, comp_handle,
                                  target](const Status& status) {
      if (!status.ok()) {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle "
                << comp_handle << " failed: " << status;
        const string function_and_msg = strings::StrCat(
            errors::FormatFunctionForError(data->function_name_), " ",
            status.error_message());
        refcounted_done->UpdateStatus(
            errors::CreateWithUpdatedMessage(status, function_and_msg));
        // Cancel the execution of other component functions.
        cm->StartCancel();
      } else {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle "
                << comp_handle << " succeeded.";
        // Scatter this component's returns into the caller's output slots.
        for (int i = 0; i < comp_rets->size(); ++i) {
          (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
        }
      }
      delete comp_rets;
      // refcounted_done is thread-safe
      refcounted_done->Unref();
    };

    FunctionLibraryRuntime* flr = GetFLR(target);
    if (flr != nullptr) {
      opts_copy.remote_execution = false;
      // When target device has private thread pool, use the target device
      // runner
      thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
      opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << comp_handle;
      VLOG(4) << "    with " << opts_copy.DebugString();

      std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
      flr->Run(
          opts_copy, comp_handle, GetLocalArgs(comp_args.args),
          comp_tensor_rets,
          TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
                                            std::move(component_fn_callback)));
    } else {
      opts_copy.remote_execution = true;

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << comp_handle;
      VLOG(4) << "    with " << opts_copy.DebugString();

      RunInternal(opts_copy, comp_handle, comp_args.args, comp_rets,
                  cleanup_items, std::move(component_fn_callback));
    }
  }
  // Drop the creation reference; `done` runs once all components have too.
  refcounted_done->Unref();
}
1505
1506Status ProcessFunctionLibraryRuntime::Instantiate(
1507 const string& function_name, AttrSlice attrs,
1508 const FunctionLibraryRuntime::InstantiateOptions& options,
1509 FunctionLibraryRuntime::Handle* handle) {
1510 if (options.is_multi_device_function) {
1511 return InstantiateMultiDevice(function_name, attrs, options, handle);
1512 }
1513
1514 *handle = kInvalidHandle;
1515 FunctionLibraryRuntime* flr = GetFLR(options.target);
1516 if (flr != nullptr) {
1517 return flr->Instantiate(function_name, attrs, options, handle);
1518 }
1519
1520 Status status;
1521 Notification notification;
1522 InstantiateRemote(function_name, attrs, options, handle,
1523 [&status, &notification](const Status& s) {
1524 status = s;
1525 notification.Notify();
1526 });
1527 notification.WaitForNotification();
1528 return status;
1529}
1530
1531Status ProcessFunctionLibraryRuntime::IsCrossProcess(
1532 FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const {
1533 tf_shared_lock l(mu_);
1534 const auto& mdevice_it = mdevice_data_.find(handle);
1535 if (mdevice_it != mdevice_data_.end()) {
1536 *is_cross_process = mdevice_it->second->is_cross_process_;
1537 return OkStatus();
1538 }
1539 const auto& it = function_data_.find(handle);
1540 if (it != function_data_.end()) {
1541 *is_cross_process = it->second->is_cross_process();
1542 return OkStatus();
1543 }
1544 return errors::InvalidArgument("Handle ", handle, " not found.");
1545}
1546
1547void ProcessFunctionLibraryRuntime::InstantiateRemote(
1548 const string& function_name, AttrSlice attrs,
1549 const FunctionLibraryRuntime::InstantiateOptions& options,
1550 FunctionLibraryRuntime::Handle* handle,
1551 FunctionLibraryRuntime::DoneCallback done) {
1552 if (parent_ == nullptr) {
1553 done(errors::Internal(
1554 "Currently don't support instantiating functions on device: ",
1555 options.target));
1556 return;
1557 }
1558 auto target = options.target;
1559 VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target;
1560 string function_key = Canonicalize(function_name, attrs, options);
1561 FunctionData* f;
1562 {
1563 mutex_lock l(mu_);
1564 FunctionLibraryRuntime::Handle h =
1565 gtl::FindWithDefault(table_, function_key, kInvalidHandle);
1566 if (h == kInvalidHandle || function_data_.count(h) == 0) {
1567 h = AddHandleLocked(function_key, target, kInvalidHandle);
1568 }
1569 f = function_data_[h].get();
1570 *handle = h;
1571 }
1572 f->DistributedInit(
1573 parent_, function_name,
1574 options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options,
1575 [this, function_name, target, handle, done](const Status& s) {
1576 VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
1577 << " on: " << target << " with handle: " << *handle
1578 << " (this: " << this << ")";
1579 done(s);
1580 });
1581}
1582
1583Status ProcessFunctionLibraryRuntime::RemoveHandle(
1584 FunctionLibraryRuntime::Handle handle) {
1585 mutex_lock l(mu_);
1586 table_.erase(function_data_[handle]->function_key());
1587 function_data_.erase(handle);
1588 return OkStatus();
1589}
1590
Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
    FunctionLibraryRuntime::Handle handle) {
  std::unique_ptr<MultiDeviceFunctionData> mdata;
  {
    mutex_lock l(mu_);
    auto it = mdevice_data_.find(handle);
    // NOTE(review): `it` is dereferenced without an end() check — assumes
    // `handle` was previously instantiated as multi-device; confirm callers.
    --it->second->instantiation_counter_;
    if (it->second->instantiation_counter_ != 0) {
      // Other instantiations still reference this handle; keep the data.
      return OkStatus();
    }
    // Last instantiation: take ownership of the data and remove both lookup
    // entries while still holding the lock, then release components outside
    // the lock.
    mdata = std::move(it->second);
    table_.erase(mdata->function_key_);
    mdevice_data_.erase(it);
  }

  // If we are here we are releasing the last instantiation of `handle`.
  // Release all component function handles.
  Status overall_status;
  for (const auto& it : mdata->glue_) {
    const string& device = it.first;
    FunctionLibraryRuntime::Handle flr_handle = it.second.handle;
    FunctionLibraryRuntime* flr = GetFLR(device);
    if (flr == nullptr) {
      // TODO(nareshmodi): Implement DeregisterGraph call to remote device if
      // parent is not null.
      if (parent_ != nullptr) {
        return errors::Unimplemented(
            "Releasing a multi-device component handle on a remote device is "
            "not yet implemented.");
      }
      return errors::InvalidArgument(
          "Failed to find FunctionLibraryRuntime for device ", device,
          " when releasing multi-device function handle ", handle);
    }
    // Record the most recent failure but keep releasing remaining components.
    Status status = flr->ReleaseHandle(flr_handle);
    if (!status.ok()) {
      overall_status = status;
    }
  }

  return overall_status;
}
1633
1634Status ProcessFunctionLibraryRuntime::ReleaseHandle(
1635 FunctionLibraryRuntime::Handle handle) {
1636 // Return directly if all function handles has already been released.
1637 if (flr_map_ == nullptr) return OkStatus();
1638
1639 if (IsMultiDevice(handle)) {
1640 return ReleaseMultiDeviceHandle(handle);
1641 }
1642
1643 FunctionLibraryRuntime* flr = nullptr;
1644 string target_device;
1645 {
1646 mutex_lock l(mu_);
1647 CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
1648 target_device = function_data_[handle]->target_device();
1649 }
1650 flr = GetFLR(target_device);
1651 if (flr != nullptr) {
1652 return flr->ReleaseHandle(handle);
1653 }
1654 return errors::InvalidArgument("Handle not found: ", handle);
1655}
1656
1657void ProcessFunctionLibraryRuntime::CleanupCreatedRendezvous(
1658 const Rendezvous* created_rendezvous, const int64_t step_id) const {
1659 if (created_rendezvous) {
1660 DCHECK(rendezvous_factory_);
1661 created_rendezvous->Unref();
1662 Status s = rendezvous_factory_.CleanUp(step_id);
1663 if (!s.ok()) {
1664 LOG(ERROR) << s;
1665 }
1666 }
1667}
1668
1669FunctionLibraryRuntime::DoneCallback
1670ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
1671 std::vector<std::unique_ptr<CleanUpItem>>* items,
1672 FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
1673 const Rendezvous* created_rendezvous) const {
1674 return [this, items, done = std::move(done), step_id,
1675 created_rendezvous](const Status& status) {
1676 this->CleanupCreatedRendezvous(created_rendezvous, step_id);
1677 auto* local_status = new Status(status);
1678 CleanUp(items, [local_status, done](const Status& cleanup_status) {
1679 local_status->Update(cleanup_status);
1680 done(*local_status);
1681 delete local_status;
1682 });
1683 delete items;
1684 };
1685}
1686
1687Status ProcessFunctionLibraryRuntime::CreateRendezvous(
1688 FunctionLibraryRuntime::Options& opts,
1689 Rendezvous** created_rendezvous) const {
1690 DCHECK(opts.rendezvous == nullptr);
1691 if (!rendezvous_factory_) {
1692 return errors::FailedPrecondition(
1693 "The caller does not provide a rendezvous and "
1694 "ProcessFunctionLibraryRuntime was created without a rendezvous "
1695 "factory.");
1696 }
1697 Status s = rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous);
1698 if (s.ok()) {
1699 opts.rendezvous = *created_rendezvous;
1700 opts.create_rendezvous = false;
1701 }
1702 return s;
1703}
1704
1705Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1706 const gtl::ArraySlice<Tensor> args,
1707 const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1708 ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1709 // "Index"s of _Arg nodes are unique when all arguments are local Tensors.
1710 for (const auto& it : comp_data.arg_indices) {
1711 if (it.index >= args.size()) {
1712 return errors::InvalidArgument("index ", it.index,
1713 " is out of range [0, ", args.size(), ")");
1714 }
1715 if (it.sub_index >= 0) {
1716 const Tensor& t = args[it.index];
1717 if (t.dtype() != DT_RESOURCE) {
1718 return errors::InvalidArgument("Got unexpected sub_index ",
1719 it.sub_index, " for argument ",
1720 it.index);
1721 }
1722 const auto& handles = t.flat<ResourceHandle>();
1723 if (it.sub_index >= handles.size()) {
1724 return errors::InvalidArgument("Sub_index ", it.sub_index,
1725 "is out of range [0,", handles.size(),
1726 ") for argument ", it.index);
1727 }
1728 comp_args->args.push_back(Tensor(handles(it.sub_index)));
1729 } else {
1730 comp_args->args.push_back(args[it.index]);
1731 }
1732 }
1733 return OkStatus();
1734}
1735
1736#if !defined(IS_MOBILE_PLATFORM)
1737Status ProcessFunctionLibraryRuntime::GetComponentArgs(
1738 const FunctionArgsInterface& args,
1739 const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data,
1740 ProcessFunctionLibraryRuntime::InternalArgs* comp_args) {
1741 for (int i = 0; i < comp_data.arg_indices.size(); ++i) {
1742 const FunctionArgIndex index = comp_data.arg_indices.at(i);
1743 Tensor tensor;
1744 if (args.GetLocalArg(index, &tensor).ok()) {
1745 comp_args->args.push_back(std::move(tensor));
1746 } else {
1747 eager::RemoteTensorHandle remote_handle;
1748 TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle));
1749 comp_args->remote_args.emplace_back(
1750 std::make_unique<eager::RemoteTensorHandle>(
1751 std::move(remote_handle)));
1752 comp_args->args.push_back(comp_args->remote_args.back().get());
1753 }
1754 }
1755 return OkStatus();
1756}
1757#endif // IS_MOBILE_PLATFORM
1758
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  // Runs the function identified by `handle` asynchronously; `done` is
  // invoked once with the final status. If the caller supplied no
  // rendezvous, create one here; the cleanup wrapper below releases it.
  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    Status s = CreateRendezvous(new_opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
  }

  // `cleanup_items` collects remote-execution state to tear down after the
  // run; ownership passes to the wrapped `done` callback.
  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
                                    new_opts.step_id, created_rendezvous);
  // Results are produced internally as FunctionRets; this wrapper converts
  // them back into the caller's Tensor vector on success and owns (deletes)
  // the temporary vector either way.
  std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
  done = [rets, function_rets, done = std::move(done)](const Status& s) {
    Status status = s;
    if (status.ok()) {
      status.Update(FunctionRetsToTensors(function_rets, rets));
    }
    delete function_rets;
    done(status);
  };
  bool multi_device = HasMultiDeviceHandle(handle);
  if (multi_device) {
    // Multi-device path: each component function looks up its own args.
    auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                      InternalArgs* comp_args) -> Status {
      return GetComponentArgs(args, comp_data, comp_args);
    };
    return RunMultiDeviceAsync(new_opts, handle, function_rets, cleanup_items,
                               std::move(done), std::move(get_component_args));
  }
  // Single-device path: wrap the tensors as FunctionArgs for RunInternal.
  std::vector<FunctionArg> local_args;
  for (const auto& tensor : args) {
    local_args.push_back(tensor);
  }
  RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
              std::move(done));
}
1802
// This method handles the simple remote call case (not multi-device).
void ProcessFunctionLibraryRuntime::RunInternal(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
    std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    // Look up where the function was instantiated; a shared (read) lock is
    // sufficient since we only read function_data_.
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: ", handle, " not found."));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution."));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    // The target device has a local FLR: ship the args to it over the
    // rendezvous, run there, then receive the results back on the source
    // device.
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    // Incarnation numbers identify a particular boot of each device so stale
    // sends/receives can be detected.
    int64_t src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    std::vector<Tensor> local_args = GetLocalArgs(args);

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_", src_incarnation,
                    local_args, device_context, opts.args_alloc_attrs,
                    rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    // `remote_rets` is owned by (and deleted inside) the callback below;
    // only its size is needed to know how many tensors to receive.
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    flr->Run(opts, handle, local_args, remote_rets,
             [source_device, target_device, target_incarnation, rendezvous,
              device_context, rets_alloc_attrs, remote_rets, rets,
              done = std::move(done)](const Status& status) mutable {
               if (!status.ok()) {
                 delete remote_rets;
                 done(status);
                 return;
               }
               int64_t num_returns = remote_rets->size();
               delete remote_rets;
               // Now receive the return values from the target.
               std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
               ReceiveTensorsAsync(target_device, source_device, "ret_",
                                   target_incarnation, num_returns,
                                   device_context, rets_alloc_attrs, rendezvous,
                                   recv_tensors,
                                   TensorsToFunctionRetsDoneCallback(
                                       rets, recv_tensors, std::move(done)));
             });
    return;
  }
  if (parent_ != nullptr) {
    // No local FLR for the target: delegate the run to the parent runtime
    // and record a cleanup item so remote state is torn down afterwards.
    auto cleanup_item = std::make_unique<CleanUpItem>();
    cleanup_item->device = target_device;
    cleanup_item->step_id = opts.step_id;
    cleanup_item->local_handle = local_handle;
    cleanup_items->emplace_back(std::move(cleanup_item));
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device"));
}
1897
1898void ProcessFunctionLibraryRuntime::Run(
1899 const FunctionLibraryRuntime::Options& opts,
1900 FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
1901 FunctionLibraryRuntime::DoneCallback done) const {
1902 std::vector<Tensor> args;
1903 args.reserve(frame->num_args());
1904 for (size_t i = 0; i < frame->num_args(); ++i) {
1905 const Tensor* arg;
1906 Status s = frame->GetArg(i, &arg);
1907 args.emplace_back(*arg);
1908 if (!s.ok()) {
1909 done(s);
1910 }
1911 }
1912 std::vector<Tensor>* rets = new std::vector<Tensor>;
1913 rets->reserve(frame->num_retvals());
1914
1915 Run(opts, handle, args, rets,
1916
1917 [frame, rets, done = std::move(done)](const Status& status) {
1918 std::unique_ptr<std::vector<Tensor>> rets_releaser(rets);
1919
1920 if (!status.ok()) {
1921 done(status);
1922 return;
1923 }
1924
1925 if (rets->size() != frame->num_retvals()) {
1926 done(errors::Internal(
1927 "Number of return values from function (", rets->size(),
1928 ") did not match expected number of return values (",
1929 frame->num_retvals(), ")."));
1930 return;
1931 }
1932
1933 for (size_t i = 0; i < frame->num_retvals(); ++i) {
1934 Status s = frame->SetRetval(i, (*rets)[i]);
1935 if (!s.ok()) {
1936 done(s);
1937 return;
1938 }
1939 }
1940 done(OkStatus());
1941 });
1942}
1943
Status ProcessFunctionLibraryRuntime::RunSync(
    const FunctionLibraryRuntime::Options& orig_opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets) const {
  MultiDeviceFunctionData* multi_device_data = IsMultiDevice(handle);
  if (multi_device_data && multi_device_data->enable_sync_execution) {
    // Truly synchronous path: run the multi-device function on this thread.
    metrics::IncrementTestCounter("pflr_runsync", "sync");
    FunctionLibraryRuntime::Options new_opts = orig_opts;
    Rendezvous* created_rendezvous = nullptr;
    if (!new_opts.rendezvous) {
      TF_RETURN_IF_ERROR(CreateRendezvous(new_opts, &created_rendezvous));
    }

    std::vector<FunctionRet> function_rets;
    auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                      InternalArgs* comp_args) {
      return GetComponentArgs(args, comp_data, comp_args);
    };

    Status status = RunMultiDeviceSync(new_opts, handle, &function_rets,
                                       std::move(get_component_args));
    // Release the rendezvous created above (no-op when the caller supplied
    // one), then convert the internal FunctionRets into plain Tensors.
    CleanupCreatedRendezvous(created_rendezvous, new_opts.step_id);
    status.Update(FunctionRetsToTensors(&function_rets, rets));
    return status;
  } else {
    // TODO(b/207484417): Either handle or avoid/delete this fallback path.
    // Fallback: run through the async API and block until it completes.
    metrics::IncrementTestCounter("pflr_runsync", "async");
    Notification n;
    Status s;
    Run(orig_opts, handle, args, rets, [&n, &s](const Status& status) {
      s.Update(status);
      n.Notify();
    });
    n.WaitForNotification();
    return s;
  }
}
1981
1982Status ProcessFunctionLibraryRuntime::RunSync(
1983 const FunctionLibraryRuntime::Options& opts,
1984 FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
1985 // TODO(b/207485199): Implement this as synchronous code.
1986 Notification n;
1987 Status s;
1988 Run(opts, handle, frame, [&n, &s](const Status& status) {
1989 s.Update(status);
1990 n.Notify();
1991 });
1992 n.WaitForNotification();
1993 return s;
1994}
1995
1996void ProcessFunctionLibraryRuntime::Run(
1997 const FunctionLibraryRuntime::Options& opts,
1998 FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
1999 std::vector<FunctionRet>* rets,
2000 FunctionLibraryRuntime::DoneCallback done) const {
2001 bool has_remote_outputs = false;
2002 const MultiDeviceFunctionData* data = IsMultiDevice(handle);
2003 if (data != nullptr) {
2004 has_remote_outputs = data->has_remote_outputs;
2005 }
2006 if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) {
2007 const std::vector<Tensor> local_inputs = args.GetLocalTensors();
2008 std::vector<Tensor>* tensor_rets = new std::vector<Tensor>;
2009 return Run(
2010 opts, handle, local_inputs, tensor_rets,
2011 TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done)));
2012 }
2013
2014 FunctionLibraryRuntime::Options new_opts = opts;
2015 Rendezvous* created_rendezvous = nullptr;
2016 if (!opts.rendezvous) {
2017 Status s = CreateRendezvous(new_opts, &created_rendezvous);
2018 if (!s.ok()) {
2019 done(s);
2020 return;
2021 }
2022 }
2023
2024#if defined(IS_MOBILE_PLATFORM)
2025 done(errors::Unimplemented(
2026 "Remote inputs are not available on mobile devices."));
2027 return;
2028#else // !IS_MOBILE_PLATFORM
2029 auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
2030 done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
2031 created_rendezvous);
2032
2033 auto get_component_args = [&args](const ComponentFunctionData& comp_data,
2034 InternalArgs* comp_args) -> Status {
2035 return GetComponentArgs(args, comp_data, comp_args);
2036 };
2037 return RunMultiDeviceAsync(new_opts, handle, rets, cleanup_items,
2038 std::move(done), std::move(get_component_args));
2039#endif // !IS_MOBILE_PLATFORM
2040}
2041
void ProcessFunctionLibraryRuntime::CleanUp(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done) const {
  // Fans out one cleanup call per item. `refcounted_done` invokes `done`
  // once — when the last outstanding reference (one per item plus the
  // initial one) is dropped. Does NOT take ownership of `items`.
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (auto& item : *items) {
    // One reference per item; released on each branch below.
    refcounted_done->Ref();
    auto* flr = GetFLR(item->device);
    if (flr != nullptr) {
      // TODO(fishx): cleanup state for local execution.
      refcounted_done->UpdateStatus(
          errors::Internal("Cleanup items shouldn't contain local item."));
      refcounted_done->Unref();
    } else if (parent_ != nullptr) {
      // Remote item: ask the parent runtime to clean up asynchronously; the
      // callback releases this item's reference.
      parent_->CleanUp(item->step_id, item->local_handle,
                       [refcounted_done](const Status& status) {
                         if (!status.ok()) {
                           refcounted_done->UpdateStatus(status);
                         }
                         // refcounted_done is thread-safe
                         refcounted_done->Unref();
                       });
    } else {
      refcounted_done->UpdateStatus(
          errors::Internal("Could not find device in cleanup."));
      refcounted_done->Unref();
    }
  }
  // Drop the initial reference taken at construction.
  refcounted_done->Unref();
}
2071
2072Status ProcessFunctionLibraryRuntime::Clone(
2073 Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
2074 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
2075 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
2076 bool skip_flib_def) const {
2077 if (skip_flib_def) {
2078 *out_lib_def = std::make_unique<FunctionLibraryDefinition>(
2079 lib_def_->default_registry(), FunctionDefLibrary{});
2080 } else {
2081 *out_lib_def = std::make_unique<FunctionLibraryDefinition>(*lib_def_);
2082 }
2083 *out_pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
2084 device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
2085 out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
2086 session_metadata_, rendezvous_factory_);
2087 {
2088 tf_shared_lock l(mu_);
2089 for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d);
2090 }
2091 return OkStatus();
2092}
2093
2094} // namespace tensorflow
2095