1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
16 | |
17 | #include <algorithm> |
18 | #include <functional> |
19 | #include <iterator> |
20 | #include <optional> |
21 | #include <utility> |
22 | |
23 | #include "absl/container/flat_hash_map.h" |
24 | #include "absl/memory/memory.h" |
25 | #include "absl/strings/str_join.h" |
26 | #include "tensorflow/core/common_runtime/device_set.h" |
27 | #include "tensorflow/core/common_runtime/function.h" |
28 | #include "tensorflow/core/common_runtime/function_optimization_registry.h" |
29 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
30 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
31 | #include "tensorflow/core/common_runtime/partitioning_utils.h" |
32 | #include "tensorflow/core/common_runtime/placer.h" |
33 | #include "tensorflow/core/common_runtime/process_util.h" |
34 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
35 | #include "tensorflow/core/common_runtime/rendezvous_util.h" |
36 | #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" |
37 | #include "tensorflow/core/common_runtime/single_threaded_executor.h" |
38 | #include "tensorflow/core/framework/cancellation.h" |
39 | #include "tensorflow/core/framework/function.h" |
40 | #include "tensorflow/core/framework/graph_to_functiondef.h" |
41 | #include "tensorflow/core/framework/metrics.h" |
42 | #include "tensorflow/core/framework/op_kernel.h" |
43 | #include "tensorflow/core/framework/tensor.h" |
44 | #include "tensorflow/core/framework/types.h" |
45 | #include "tensorflow/core/framework/types.pb.h" |
46 | #include "tensorflow/core/graph/graph.h" |
47 | #include "tensorflow/core/graph/graph_node_util.h" |
48 | #include "tensorflow/core/graph/graph_partition.h" |
49 | #include "tensorflow/core/lib/core/errors.h" |
50 | #include "tensorflow/core/lib/gtl/cleanup.h" |
51 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
52 | #include "tensorflow/core/lib/gtl/map_util.h" |
53 | #include "tensorflow/core/lib/random/random.h" |
54 | #include "tensorflow/core/platform/blocking_counter.h" |
55 | #include "tensorflow/core/platform/notification.h" |
56 | #include "tensorflow/core/util/device_name_utils.h" |
57 | #include "tensorflow/core/util/dump_graph.h" |
58 | #include "tensorflow/core/util/ptr_util.h" |
59 | #include "tensorflow/core/util/reffed_status_callback.h" |
60 | #include "tensorflow/tsl/platform/statusor.h" |
61 | #if !defined(IS_MOBILE_PLATFORM) |
62 | #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" |
63 | #endif // IS_MOBILE_PLATFORM |
64 | |
65 | namespace tensorflow { |
66 | |
// Sentinel device name that maps to the device-less (nullptr-device) FLR
// created when no DeviceMgr is supplied; see GetFLR().
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
68 | |
// Instantiates this function on a remote target through the distributed FLR.
// Safe to call concurrently: only the first caller performs the remote
// Instantiate; later callers block until it completes and receive the cached
// result.
void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
    DistributedFunctionLibraryRuntime* parent, const string& function_name,
    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::DoneCallback done) {
  {
    mutex_lock l(mu_);
    is_cross_process_ = true;
    if (init_started_) {
      // Another caller owns initialization; wait for its completion signal.
      // NOTE(review): `init_result_` is never assigned in the code visible
      // here (the callback below passes `s` straight to `done`) — presumably
      // it is set elsewhere before Notify(); confirm against the header.
      init_done_.WaitForNotification();
      done(init_result_);
      return;
    }
    init_started_ = true;
  }
  // First caller: start the asynchronous remote instantiation. The callback
  // releases any waiters before reporting the status to `done`.
  parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_,
                      [this, done](const Status& s) {
                        init_done_.Notify();
                        done(s);
                      });
}
90 | |
// Builds the process-level runtime. When `device_mgr` is null, a single
// device-less FLR is registered under the nullptr key (looked up via
// kDefaultFLRDevice); otherwise one FLR is created per local device.
// `config` is copied (if present) so the runtime does not depend on the
// caller keeping it alive.
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent,
    const SessionMetadata* session_metadata,
    Rendezvous::Factory rendezvous_factory)
    : parent_(parent),
      env_(env),
      config_(config ? absl::make_optional(*config) : absl::nullopt),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      flr_map_(new std::unordered_map<Device*,
                                      std::unique_ptr<FunctionLibraryRuntime>>),
      next_handle_(0),
      session_metadata_(session_metadata),
      rendezvous_factory_(std::move(rendezvous_factory)),
      optimizer_options_(optimizer_options),
      graph_def_version_(graph_def_version) {
  if (device_mgr == nullptr) {
    // No devices: register a single FLR keyed by nullptr.
    (*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
        graph_def_version, lib_def_, default_thread_pool, optimizer_options,
        session_metadata_, this);
    return;
  }
  InitializeDeviceAndFlr();
}
121 | |
122 | /* static */ |
123 | Status ProcessFunctionLibraryRuntime::SendTensors( |
124 | const string& source_device, const string& target_device, |
125 | const string& key_prefix, int64_t src_incarnation, |
126 | gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context, |
127 | const std::vector<AllocatorAttributes>& alloc_attrs, |
128 | RendezvousInterface* rendezvous) { |
129 | std::vector<string> keys; |
130 | for (int i = 0; i < tensors_to_send.size(); ++i) { |
131 | string name = strings::StrCat(key_prefix, i); |
132 | string key = Rendezvous::CreateKey(source_device, src_incarnation, |
133 | target_device, name, FrameAndIter(0, 0)); |
134 | keys.push_back(key); |
135 | } |
136 | TF_RETURN_IF_ERROR(SendTensorsToRendezvous( |
137 | rendezvous, device_context, alloc_attrs, keys, tensors_to_send)); |
138 | return OkStatus(); |
139 | } |
140 | |
141 | /* static */ |
142 | void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( |
143 | const string& source_device, const string& target_device, |
144 | const string& key_prefix, int64_t src_incarnation, int64_t num_tensors, |
145 | DeviceContext* device_context, |
146 | const std::vector<AllocatorAttributes>& alloc_attrs, |
147 | RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors, |
148 | StatusCallback done) { |
149 | std::vector<string> keys; |
150 | for (int64_t i = 0; i < num_tensors; ++i) { |
151 | string name = strings::StrCat(key_prefix, i); |
152 | string key = Rendezvous::CreateKey(source_device, src_incarnation, |
153 | target_device, name, FrameAndIter(0, 0)); |
154 | keys.push_back(key); |
155 | } |
156 | RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys, |
157 | received_tensors, std::move(done)); |
158 | } |
159 | |
160 | Status ProcessFunctionLibraryRuntime::GetRetTypes( |
161 | FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) { |
162 | FunctionLibraryRuntime* flr = nullptr; |
163 | { |
164 | tf_shared_lock l(mu_); |
165 | auto miter = mdevice_data_.find(h); |
166 | if (miter != mdevice_data_.end()) { |
167 | *ret_types = miter->second->ret_types_; |
168 | return OkStatus(); |
169 | } |
170 | auto fiter = function_data_.find(h); |
171 | if (fiter != function_data_.end()) { |
172 | flr = GetFLR(fiter->second->target_device()); |
173 | } |
174 | } |
175 | if (flr != nullptr) { |
176 | return flr->GetRetTypes(h, ret_types); |
177 | } |
178 | return errors::InvalidArgument("Handle " , h, " not found." ); |
179 | } |
180 | |
181 | Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( |
182 | const string& device_name, int64_t* incarnation) const { |
183 | FunctionLibraryRuntime* flr = GetFLR(device_name); |
184 | if (flr == nullptr) { |
185 | return errors::InvalidArgument("Device name: " , device_name, " not found." ); |
186 | } |
187 | *incarnation = flr->device()->attributes().incarnation(); |
188 | return OkStatus(); |
189 | } |
190 | |
191 | Status ProcessFunctionLibraryRuntime::GetDeviceContext( |
192 | const string& device_name, DeviceContext** device_context) const { |
193 | *device_context = nullptr; |
194 | FunctionLibraryRuntime* flr = GetFLR(device_name); |
195 | if (flr == nullptr) { |
196 | return errors::InvalidArgument("Device name: " , device_name, " not found." ); |
197 | } |
198 | Device* device = flr->device(); |
199 | string device_type = device->parsed_name().type; |
200 | if (device_type == "CPU" || device_type == "TPU_SYSTEM" ) { |
201 | // "TPU_SYSTEM" indicates that `device` is a CPU. |
202 | return OkStatus(); |
203 | } |
204 | |
205 | if (device->IsRemoteCallAllowed()) { |
206 | auto* dev_info = flr->device()->tensorflow_accelerator_device_info(); |
207 | if (dev_info) { |
208 | *device_context = dev_info->default_context; |
209 | return OkStatus(); |
210 | } |
211 | } |
212 | |
213 | return errors::Internal("Device type: " , device_type, |
214 | " is currently unsupported for remote " , |
215 | "function executions" ); |
216 | } |
217 | |
218 | void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() { |
219 | // Reset device_set_ by one of the two following scenarios: |
220 | // 1) Both cluster-FLR and its remote_device_mgr is available: include local |
221 | // devices (if any) from the local device_mgr_ as Device type, and include |
222 | // remote devices from cluster's remote_device_mgr as RemoteDevice type. |
223 | // 2) Include local devices from the local device_mgr_. |
224 | // In both scenarios, no device is added more than one times. |
225 | mutex_lock l(mu_); |
226 | device_set_ = std::make_shared<DeviceSet>(); |
227 | if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { |
228 | for (auto d : parent_->remote_device_mgr()->ListDevices()) { |
229 | Device* device = nullptr; |
230 | if (device_mgr_->LookupDevice(d->name(), &device) == OkStatus()) { |
231 | // If this device exists in device_mgr, i.e., a local device, |
232 | // add this device from the instance included in device_mgr_ |
233 | device_set_->AddDevice(device); |
234 | } else { |
235 | device_set_->AddDevice(d); |
236 | } |
237 | } |
238 | } else { |
239 | for (auto d : device_mgr_->ListDevices()) { |
240 | device_set_->AddDevice(d); |
241 | } |
242 | } |
243 | |
244 | // Update flr_map_ by adding new devices |
245 | for (Device* d : device_mgr_->ListDevices()) { |
246 | if ((*flr_map_)[d] == nullptr) { |
247 | (*flr_map_)[d] = NewFunctionLibraryRuntime( |
248 | device_mgr_, env_, config_ ? &(*config_) : nullptr, d, |
249 | graph_def_version_, lib_def_, default_thread_pool_, |
250 | optimizer_options_, session_metadata_, this); |
251 | } |
252 | } |
253 | } |
254 | |
255 | FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( |
256 | const string& device_name) const { |
257 | Device* device = nullptr; |
258 | if (device_name != kDefaultFLRDevice) { |
259 | if (!device_mgr_->LookupDevice(device_name, &device).ok()) { |
260 | VLOG(4) << "Could not find device: " << device_name; |
261 | return nullptr; |
262 | } |
263 | } |
264 | const auto& iter = flr_map_->find(device); |
265 | if (iter == flr_map_->end()) { |
266 | VLOG(1) << "Could not find device: " << device_name |
267 | << "in the local process." ; |
268 | return nullptr; |
269 | } |
270 | return iter->second.get(); |
271 | } |
272 | |
273 | FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( |
274 | const string& function_key, const string& device_name, |
275 | FunctionLibraryRuntime::LocalHandle local_handle) { |
276 | mutex_lock l(mu_); |
277 | return AddHandleLocked(function_key, device_name, local_handle); |
278 | } |
279 | |
280 | FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked( |
281 | const string& function_key, const string& device_name, |
282 | FunctionLibraryRuntime::LocalHandle local_handle) { |
283 | auto h = next_handle_; |
284 | function_data_[h] = |
285 | std::make_unique<FunctionData>(device_name, local_handle, function_key); |
286 | table_[function_key] = h; |
287 | next_handle_++; |
288 | return h; |
289 | } |
290 | |
291 | FunctionLibraryRuntime::Handle |
292 | ProcessFunctionLibraryRuntime::AddMultiDeviceHandle( |
293 | std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) { |
294 | mutex_lock l(mu_); |
295 | auto h = next_handle_; |
296 | mdevice_data_[h] = std::move(data); |
297 | table_[function_key] = h; |
298 | next_handle_++; |
299 | return h; |
300 | } |
301 | |
302 | bool ProcessFunctionLibraryRuntime::HasMultiDeviceHandle( |
303 | FunctionLibraryRuntime::Handle handle) const { |
304 | bool multi_device; |
305 | { |
306 | tf_shared_lock l(mu_); |
307 | multi_device = mdevice_data_.find(handle) != mdevice_data_.end(); |
308 | } |
309 | return multi_device; |
310 | } |
311 | |
312 | FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle( |
313 | const string& function_key) const { |
314 | tf_shared_lock l(mu_); |
315 | return gtl::FindWithDefault(table_, function_key, kInvalidHandle); |
316 | } |
317 | |
// Resolves `handle` to the device-local handle it has on `device_name`, or
// kInvalidLocalHandle when the handle is unknown, lives on another device,
// or (unless `include_multi_device`) refers to a multi-device function.
// A multi-device function is only resolvable when it consists of exactly one
// component function placed on `device_name`.
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle,
    bool include_multi_device) const {
  tf_shared_lock l(mu_);

  auto miter = mdevice_data_.find(handle);
  if (miter != mdevice_data_.end()) {
    if (!include_multi_device) return kInvalidLocalHandle;

    const MultiDeviceFunctionData& data = *miter->second;
    // Only a single-component multi-device function can be mapped to one
    // device-local handle.
    if (data.glue_.size() != 1) return kInvalidLocalHandle;

    const auto& pair = *data.glue_.begin();
    const string& func_device_name = pair.first;
    const ComponentFunctionData& component_data = pair.second;
    if (func_device_name != device_name) return kInvalidLocalHandle;

    // Replace the given handle with the handle for the single component
    // function, then fall through to the single-device lookup below.
    handle = component_data.handle;
  }

  auto iter = function_data_.find(handle);
  if (iter == function_data_.end()) {
    return kInvalidLocalHandle;
  }
  FunctionData* function_data = iter->second.get();
  if (function_data->target_device() != device_name) {
    return kInvalidLocalHandle;
  }
  return function_data->local_handle();
}
351 | |
352 | string ProcessFunctionLibraryRuntime::GetDeviceName( |
353 | FunctionLibraryRuntime::Handle handle) const { |
354 | tf_shared_lock l(mu_); |
355 | auto iter = function_data_.find(handle); |
356 | CHECK(iter != function_data_.end()); |
357 | FunctionData* function_data = iter->second.get(); |
358 | return function_data->target_device(); |
359 | } |
360 | |
361 | ProcessFunctionLibraryRuntime::MultiDeviceFunctionData* |
362 | ProcessFunctionLibraryRuntime::IsMultiDevice( |
363 | FunctionLibraryRuntime::Handle handle) const { |
364 | tf_shared_lock l(mu_); |
365 | const auto& it = mdevice_data_.find(handle); |
366 | if (it != mdevice_data_.end()) { |
367 | return it->second.get(); |
368 | } |
369 | return nullptr; |
370 | } |
371 | |
namespace {
// Sets `group` to the first colocation group specified in `node`. If no
// group is specified, does not touch `group`.
void GetColocationGroup(const Node* node, string* group) {
  // We hoist the conversion from C-style string literal to string here,
  // so that we can avoid the many repeated calls to strlen().
  static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
  const AttrValue* attr_value =
      node->attrs().Find(kColocationAttrNameStringPiece);
  if (attr_value != nullptr && attr_value->has_list() &&
      attr_value->list().s_size() > 0) {
    *group = attr_value->list().s(0);
  }
}

// Returns the node's assigned device name when one has been set, otherwise
// its requested device name (which may be empty). Never returns nullptr.
const string* AssignedOrRequestedDeviceName(const Node& node) {
  if (node.has_assigned_device_name()) {
    return &node.assigned_device_name();
  }
  return &node.requested_device();
}

// For every DT_RESOURCE _Arg node whose index has an entry in
// `input_resource_dtypes_and_shapes`, attaches "_handle_dtypes" and
// "_handle_shapes" attributes carrying the resource's dtype and shape.
// Non-resource args and args without an entry are left untouched.
Status SetArgShape(const std::unordered_map<int, DtypeAndPartialTensorShape>&
                       input_resource_dtypes_and_shapes,
                   const std::vector<Node*>& arg_nodes) {
  for (Node* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    if (dtype == DT_RESOURCE) {
      auto dtype_and_shape_iter = input_resource_dtypes_and_shapes.find(index);
      if (dtype_and_shape_iter != input_resource_dtypes_and_shapes.end()) {
        AttrValue dtype_attr_value;
        dtype_attr_value.mutable_list()->add_type(
            dtype_and_shape_iter->second.dtype);
        n->AddAttr("_handle_dtypes", dtype_attr_value);
        TensorShapeProto shape_proto;
        dtype_and_shape_iter->second.shape.AsProto(&shape_proto);
        AttrValue shape_attr_value;
        *shape_attr_value.mutable_list()->add_shape() = shape_proto;
        n->AddAttr("_handle_shapes", shape_attr_value);
      }
    }
  }
  return OkStatus();
}

// Returns the local tensors referred by `args`. Arguments whose variant
// alternative is not index 0 (i.e. not a local Tensor) are skipped.
std::vector<Tensor> GetLocalArgs(gtl::ArraySlice<FunctionArg> args) {
  std::vector<Tensor> tensors;
  for (const auto& arg : args) {
    if (arg.index() == 0) {
      tensors.push_back(absl::get<Tensor>(arg));
    }
  }
  return tensors;
}

// Update the done callback to push Tensors in `tensors` into `rets`.
// Takes ownership of `tensors`: the returned callback deletes it after
// (conditionally) copying its contents, regardless of the status.
FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback(
    std::vector<FunctionRet>* rets, std::vector<Tensor>* tensors,
    FunctionLibraryRuntime::DoneCallback done) {
  return [rets, tensors, done = std::move(done)](const Status& s) {
    if (s.ok()) {
      for (const auto& t : *tensors) {
        rets->push_back(t);
      }
    }
    delete tensors;
    done(s);
  };
}

// Push Tensors in `function_rets` into `tensors`. Fails with Internal if any
// return value holds a non-Tensor alternative (variant index != 0).
Status FunctionRetsToTensors(const std::vector<FunctionRet>* function_rets,
                             std::vector<Tensor>* tensors) {
  for (const auto& ret : *function_rets) {
    if (ret.index() != 0) {
      return errors::Internal(
          "Expect a Tensor as a function output but got a TensorShape.");
    }
    tensors->push_back(absl::get<Tensor>(ret));
  }
  return OkStatus();
}

}  // anonymous namespace
460 | |
// Assigns devices to the _Arg and _Retval nodes of a multi-device function
// graph prior to placement.
//
// Args are pinned directly to `input_devices[index]`. Retvals are pinned to
// `output_devices[index]` when output devices were provided; otherwise each
// retval's device is inferred from its producing node (following chains of
// Identity nodes), from a colocation constraint on that node, or left for
// the placer to decide. Returns InvalidArgument when an explicit source
// device cannot be parsed or matches zero / ambiguously-many devices in
// `device_set`.
Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
    const std::vector<string>& input_devices,
    const std::vector<string>& output_devices, const DeviceSet& device_set,
    const std::vector<Node*>& arg_nodes, const std::vector<Node*>& ret_nodes,
    const FunctionLibraryDefinition* lib_def, Device* default_device) {
  // If output_devices are not specified, we want to set the output device
  // based on the device of the output producing node. The output producing
  // node can be an arg node because functions can simply return their
  // arguments. To make sure that the output producing nodes have assigned
  // devices, we assign them to arguments first.
  for (Node* node : arg_nodes) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
    int64_t index = attr_value->i();
    node->set_assigned_device_name(input_devices[index]);
  }

  for (Node* node : ret_nodes) {
    if (output_devices.empty()) {
      DataType dtype;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));

      VLOG(3) << "Trying to determine device for node " << node->name()
              << "[T=" << DataTypeString(dtype) << "]";

      // If output_devices are empty, the node producing retval
      // must have explicitly assigned device or a colocation constraint
      // to a node with explicitly assigned device.
      for (const auto& it : node->in_edges()) {
        if (it->IsControlEdge()) continue;

        Node* src_node = it->src();
        const string* src_device = AssignedOrRequestedDeviceName(*src_node);
        string colocation_group = "";
        GetColocationGroup(src_node, &colocation_group);
        VLOG(3) << "Considering src: " << src_node->name()
                << " src_device: " << *src_device
                << " colo group: " << colocation_group;
        // Walk back through device-less, unconstrained Identity nodes to
        // find the real producer of the value.
        while (src_device->empty() && colocation_group.empty() &&
               src_node->IsIdentity()) {
          // Only follows the real data input of Identity, not control edges.
          Node* input_node;
          TF_RETURN_IF_ERROR(src_node->input_node(0, &input_node));
          src_node = input_node;

          src_device = AssignedOrRequestedDeviceName(*src_node);
          GetColocationGroup(src_node, &colocation_group);
          VLOG(3) << "Considering src: " << src_node->name()
                  << " src_device: " << *src_device
                  << " colo group: " << colocation_group;
        }

        // If resource is produced by a function call node, we can't trust
        // source node device assignment, because multi-device functions can
        // return resource placed on multiple devices. In such case we leave
        // retval device assignment empty, and rely on placer to infer correct
        // assignment based on actual output device.
        const bool can_use_src_node_device =
            !(dtype == DT_RESOURCE && IsFunctionCall(*lib_def, *src_node));

        if (!colocation_group.empty()) {
          // Propagate the producer's colocation constraint to the retval and
          // let the placer co-locate them.
          AttrValue::ListValue colo_attr;
          colo_attr.add_s(colocation_group);
          std::vector<string> colo_slice = {colocation_group};
          node->AddAttr(kColocationAttrName, colo_slice);
        } else if (!src_device->empty() && can_use_src_node_device) {
          // Do not copy device from src node for variants, unless it is a no-op
          // forward from input to output. This gets handled in
          // colocation_graph.cc which has special logic for correctly placing
          // _Retvals for various variant types.
          if (dtype == DT_VARIANT && !src_node->IsArg()) {
            continue;
          }
          // src_device can be a partially specified device. Find the
          // matching device in the device_set.
          DeviceNameUtils::ParsedName parsed;
          if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
            return errors::InvalidArgument(
                "Failed to parse explicit device specification ", *src_device);
          }
          std::vector<Device*> matching_devices;
          device_set.FindMatchingDevices(parsed, &matching_devices);
          if (matching_devices.empty()) {
            if (default_device != nullptr) {
              matching_devices.push_back(default_device);
            } else {
              return errors::InvalidArgument(
                  "Unable to find any devices for spec ", *src_device);
            }
          } else if (matching_devices.size() != 1) {
            // Multiple matches: only proceed when the ambiguity can be
            // resolved (same address space, or exactly one match colocated
            // with the default device); otherwise report an error below.
            bool on_same_task = true;
            for (int i = 1; i < matching_devices.size(); ++i) {
              if (!DeviceNameUtils::IsSameAddressSpace(
                      matching_devices.at(0)->parsed_name(),
                      matching_devices.at(i)->parsed_name())) {
                on_same_task = false;
                break;
              }
            }
            // If the src node of an output is assigned to a address space (e.g.
            // py_func), rely on placer to assign a device to the output.
            if (on_same_task) {
              continue;
            }
            // Compare with default_device if it has a narrower scope matching
            // requested device.
            if (default_device != nullptr) {
              int colocated_on_default_device = 0;
              for (int i = 0; i < matching_devices.size(); ++i) {
                if (DeviceNameUtils::IsSameAddressSpace(
                        default_device->parsed_name(),
                        matching_devices.at(i)->parsed_name())) {
                  colocated_on_default_device++;
                }
              }
              // Continue to raise error if multiple colocated devices are
              // found.
              if (colocated_on_default_device == 1) {
                continue;
              }
            }
            // Convert a vector of devices to a string.
            // Using absl::StrJoin did not work in Android builds.
            string devices = "[";
            for (Device* device : matching_devices) {
              devices.append(device->name());
              devices.append(", ");
            }
            if (devices.size() > 2) {
              devices.resize(devices.size() - 2);
            }
            devices.append("]");

            return errors::InvalidArgument(
                *src_device,
                "When FunctionLibraryRuntime::Options.output_devices are "
                "not specified for a multi-device function, the device "
                "specification on the output node must match exactly one "
                "device. Matched devices are ",
                devices);
          }
          VLOG(3) << "Setting output device to " << matching_devices[0]->name()
                  << " for node " << SummarizeNode(*node);
          node->set_assigned_device_name(matching_devices[0]->name());
        } else if (!src_device->empty() && !can_use_src_node_device) {
          VLOG(3) << "Did not set device for a resource output node "
                  << SummarizeNode(*node);
        }
      }
    } else {
      const AttrValue* attr_value;
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int64_t index = attr_value->i();
      // output_devices size is checked in InstantiateMultiDevice
      DCHECK_GT(output_devices.size(), index);
      VLOG(3) << "Setting output device to " << output_devices[index]
              << " for return at index " << index;
      node->set_assigned_device_name(output_devices[index]);
    }
  }
  return OkStatus();
}
623 | |
namespace {

// Rejects list-typed (number_attr / type_list_attr) arguments, which
// multi-device functions do not support. `arg_type` ("input"/"output") is
// only used in the error message.
Status ValidateNoListArguments(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
    const string& function_name) {
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      return errors::InvalidArgument(
          "Function ", function_name, " has an ", arg_type, " named \"",
          arg.name(),
          "\" that is a list of tensors."
          " Multi-device functions support only single-tensor inputs "
          " and outputs");
    }
  }
  return OkStatus();
}

// Checks that `fdef` + `options` form a valid multi-device instantiation:
// no list inputs/outputs, no ints-on-device attribute, one input device per
// argument, and (when provided) one output device per return value.
Status ValidateMultiDeviceOptions(
    const FunctionDef& fdef,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const OpDef& signature = fdef.signature();
  // Multi-device functions currently do not support list inputs or outputs.
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
                                             signature.name()));
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
                                             signature.name()));
  if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
    return errors::Unimplemented(
        "Function '", signature.name(), "' has `",
        FunctionLibraryDefinition::kIntsOnDeviceAttr,
        "` attribute set. This attribute is not currently supported by "
        "multi-device functions.");
  }
  if (options.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.input_devices must have the same length "
        "as the number of arguments: input_devices length = ",
        options.input_devices.size(),
        " number of arguments = ", signature.input_arg_size());
  }
  if (!options.output_devices.empty() &&
      options.output_devices.size() != signature.output_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.output_devices must either be empty or have the "
        "same length as the number of arguments: output_devices length = ",
        options.output_devices.size(),
        " number of arguments = ", signature.output_arg_size());
  }
  return OkStatus();
}

}  // anonymous namespace
678 | |
679 | ProcessFunctionLibraryRuntime::AsyncAttributes::Summary |
680 | ProcessFunctionLibraryRuntime::AsyncAttributes::Summarize(const Graph* graph) { |
681 | bool has_send_op = false; |
682 | bool has_recv_op = false; |
683 | bool has_unsafe_op = false; |
684 | for (const Node* node : graph->nodes()) { |
685 | if (node->IsSend() || node->IsHostSend()) { |
686 | has_send_op = true; |
687 | } |
688 | if (node->IsRecv() || node->IsHostRecv()) { |
689 | has_recv_op = true; |
690 | } |
691 | if (!ValidateOpIsSafeForSyncExecution(*node, |
692 | allow_control_flow_sync_execution()) |
693 | .ok()) { |
694 | has_unsafe_op = true; |
695 | } |
696 | } |
697 | // (1) Anything completely unsupported? |
698 | if (has_unsafe_op) { |
699 | metrics::IncrementTestCounter("subgraph_async_summary" , "unsafe_op" ); |
700 | return AsyncAttributes::kAsyncRequired; |
701 | } |
702 | // (2) That only leaves send/recv. If neither, then it's safe. |
703 | if (!has_send_op && !has_recv_op) { |
704 | metrics::IncrementTestCounter("subgraph_async_summary" , "safe_for_sync" ); |
705 | return AsyncAttributes::kSafeForSync; |
706 | } |
707 | // (3) If each subgraph has only send or only recv, then it's possible to |
708 | // order them to run sequentially without deadlock. |
709 | if (has_send_op && !has_recv_op) { |
710 | metrics::IncrementTestCounter("subgraph_async_summary" , "send_only" ); |
711 | return AsyncAttributes::kSendOnly; |
712 | } |
713 | if (has_recv_op && !has_send_op) { |
714 | metrics::IncrementTestCounter("subgraph_async_summary" , "recv_only" ); |
715 | return AsyncAttributes::kRecvOnly; |
716 | } |
717 | // Otherwise, assume it's unsupported. |
718 | metrics::IncrementTestCounter("subgraph_async_summary" , "other" ); |
719 | return AsyncAttributes::kAsyncRequired; |
720 | } |
721 | |
722 | Status GetGraphAndArgRets( |
723 | const string& function_name, AttrSlice attrs, const FunctionDef* fdef, |
724 | const FunctionLibraryDefinition* lib_def, std::unique_ptr<Graph>* graph, |
725 | std::vector<Node*>* arg_nodes, std::vector<Node*>* ret_nodes, |
726 | std::vector<string>* ret_node_names, DataTypeVector* ret_types, |
727 | std::vector<string>* control_ret_node_names) { |
728 | std::unique_ptr<FunctionBody> fbody; |
729 | // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy. |
730 | TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, attrs, lib_def, &fbody)); |
731 | if (!fbody) { |
732 | LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"" ; |
733 | return errors::Internal("Failed to construct FunctionBody for " , |
734 | function_name); |
735 | } |
736 | *graph = std::unique_ptr<Graph>(fbody->graph); |
737 | arg_nodes->reserve(fbody->arg_nodes.size()); |
738 | std::copy(fbody->arg_nodes.begin(), fbody->arg_nodes.end(), |
739 | std::back_inserter(*arg_nodes)); |
740 | ret_nodes->reserve(fbody->ret_nodes.size()); |
741 | std::copy(fbody->ret_nodes.begin(), fbody->ret_nodes.end(), |
742 | std::back_inserter(*ret_nodes)); |
743 | fbody->graph = nullptr; |
744 | ret_node_names->reserve(fbody->ret_nodes.size()); |
745 | for (const Node* node : fbody->ret_nodes) { |
746 | ret_node_names->push_back(node->name()); |
747 | } |
748 | for (const auto& ret_type : fbody->ret_types) { |
749 | ret_types->push_back(ret_type); |
750 | } |
751 | control_ret_node_names->reserve(fbody->control_ret_nodes.size()); |
752 | for (const Node* node : fbody->control_ret_nodes) { |
753 | control_ret_node_names->push_back(node->name()); |
754 | } |
755 | return OkStatus(); |
756 | } |
757 | |
// Builds and optimizes the graph for `function_name`: looks up the
// FunctionDef, converts it to a Graph, pins args/retvals to devices, runs the
// registered optimization pass groupings (PRE_PLACEMENT, Placer,
// POST_PLACEMENT, POST_REWRITE_FOR_EXEC) plus an optional caller-supplied
// `options.optimize_graph_fn`, and returns the optimized graph together with
// the reachable function library, the control-ret mapping, and return types.
// Pass groupings are skipped for component functions (already processed with
// the main function).
StatusOr<ProcessFunctionLibraryRuntime::OptimizedFunctionGraphInfo>
ProcessFunctionLibraryRuntime::OptimizeFunctionGraph(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    const std::shared_ptr<DeviceSet>& dev_set) {
  // The caller may override the function library used to resolve
  // `function_name`; otherwise fall back to the runtime's own.
  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<Node*> arg_nodes, ret_nodes;
  std::vector<string> ret_node_names;
  DataTypeVector ret_types;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndArgRets(
      function_name, attrs, fdef, lib_def, &graph, &arg_nodes, &ret_nodes,
      &ret_node_names, &ret_types, &control_ret_node_names));

  // Keep only the function definitions reachable from this graph, so the
  // library carried around (and optionally collected) stays small.
  GraphDef graph_def;
  graph->ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_lib_def =
      lib_def->ReachableDefinitions(graph_def);
  *graph_def.mutable_library() = reachable_lib_def.ToProto();
  if (options.graph_collector != nullptr) {
    options.graph_collector->CollectRawGraph(graph_def);
  }

  Device* default_device = nullptr;
  if (options.default_device_to_target && !options.target.empty()) {
    // Make the `target` device the default device if nothing else is hard
    // coded. This allows the same function definition to be specialized to
    // different devices depending on the `PartitionedCallOp` device.
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }

  // Mark and assign device for each node in the graph to be compiled by
  // specified device.
  if (!options.xla_compile_device_type.empty()) {
    for (Node* node : graph->op_nodes()) {
      node->AddAttr("_xla_compile_device_type",
                    options.xla_compile_device_type);
      if (default_device) {
        node->set_assigned_device_name(default_device->name());
      }
    }
  }

  TF_RETURN_IF_ERROR(
      SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
  // Pin the argument/return nodes to their requested devices; the default
  // device is only offered as a soft-placement fallback.
  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, *dev_set, arg_nodes,
      ret_nodes, lib_def_,
      options.config_proto.allow_soft_placement() ? default_device : nullptr));

  // The runtime shouldn't depend on duplication between the function library
  // owned by the graph and the one owned by the runtime. To ensure this, for
  // now we ensure that the graph function library is empty and the runtime
  // library receives the query from LookUps on the graph function library.
  graph->mutable_flib_def()->set_default_registry(&reachable_lib_def);
  graph->mutable_flib_def()->Clear();

  // Do not run function/graph optimization passes for component functions,
  // since they have already processed the main function.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (!should_run_optimization_passes) {
    VLOG(1) << "Skipping function/graph optimization passes when instantiating "
               "component function "
            << function_name;
  }

  // Mapping from a function body node name to the control output name.
  std::unordered_map<string, string> node_name_to_control_ret;

  bool control_rets_updated = false;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
        *dev_set, options.config_proto, &graph, &reachable_lib_def,
        &control_ret_node_names, &control_rets_updated));
  }

  if (control_rets_updated) {
    // Function graph pass may have resulted in different nodes/node names for
    // control rets.
    for (const auto& control_ret : control_ret_node_names) {
      node_name_to_control_ret.emplace(control_ret, control_ret);
    }
  } else {
    for (const auto& control_ret : fdef->control_ret()) {
      node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
    }
  }

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &reachable_lib_def;
  optimization_options.device_set = dev_set.get();
  optimization_options.is_function_graph = true;
  std::vector<CompositeDevice*> composite_devices;
  {
    // Snapshot the composite devices under the shared lock.
    tf_shared_lock l(mu_);
    for (auto* d : composite_devices_) composite_devices.push_back(d);
  }
  optimization_options.composite_devices = &composite_devices;
  optimization_options.default_function_device = default_device;
  optimization_options.function_def = fdef;
  optimization_options.shape_inference_on_tfe_dialect_import =
      options.shape_inference_on_tfe_dialect_import;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in case where nested function call options are ignored.
  DumpGraph("Before calling Placer", graph.get());
  Placer placer(graph.get(), function_name, optimization_options.flib_def,
                dev_set.get(), default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
  }

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &reachable_lib_def, *dev_set, cpu_device, &graph);
    if (!status.ok()) {
      // Deliberately non-fatal: fall back to the unoptimized graph rather
      // than failing instantiation.
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  }

  // Detach the graph's library from `reachable_lib_def` before the latter is
  // moved into the returned struct.
  graph->mutable_flib_def()->set_default_registry(nullptr);
  graph->mutable_flib_def()->Clear();
  return OptimizedFunctionGraphInfo{
      std::move(graph), std::move(reachable_lib_def), node_name_to_control_ret,
      std::move(ret_types), ret_nodes.size()};
}
933 | |
// Instantiates the multi-device form of `function_name`: optimizes the
// function graph, partitions it into one subgraph per target device, runs
// POST_PARTITIONING passes, then instantiates each component function
// (synchronously through the local FLR when one exists for the target,
// otherwise asynchronously via InstantiateRemote). Re-uses an existing
// instantiation when the canonical function key is already in `table_`.
Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      *handle = it->second;
      // Bump the instantiation refcount so the existing data stays alive.
      ++mdevice_data_[*handle]->instantiation_counter_;
      return OkStatus();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    int index = 0;
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << " [input " << index++ << "] " << device;
    }
    index = 0;
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << " [output " << index++ << "] " << device;
    }
  }

  const std::shared_ptr<DeviceSet> dev_set = device_set();
  // Timestamps bracket the optimization + partitioning work for metrics.
  const uint64 optimization_start_time_usecs = Env::Default()->NowMicros();
  TF_ASSIGN_OR_RETURN(
      auto optimized_graph_info,
      OptimizeFunctionGraph(function_name, attrs, options, dev_set));

  auto& graph = optimized_graph_info.graph;
  // Route function lookups to the optimized library while the graph is
  // transformed further below.
  graph->mutable_flib_def()->set_default_registry(
      &(optimized_graph_info.lib_def));

  // Expand the nodes assigned to a CompositeDevice before graph partition to
  // avoid generating a subgraph on a virtual device for execution.
  // This transformation should happen as late as possible, in order to run as
  // more graph optimization passes (e.g. PRE_PLACEMENT, PLACER,
  // POST_PLACEMENT, POST_REWRITE_FOR_EXEC) on a smaller graph as possible.
  TF_RETURN_IF_ERROR(ReplicatePerReplicaNodesInFunctionGraph(
      options.composite_devices, graph.get()));

  const FunctionLibraryDefinition* lib_def =
      options.lib_def == nullptr ? lib_def_ : options.lib_def;
  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  VLOG(4) << "Main function graph to be partitioned:";
  VLOG(4) << DebugString(graph->ToGraphDefDebug());

  // Partition the graph: one subgraph per target device name. Note that
  // `graph` is consumed here.
  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(*dev_set, std::move(graph), &subgraphs));

  for (const auto& pair : subgraphs) {
    DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
                              pair.first, ")"),
              pair.second.get());
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = &(optimized_graph_info.lib_def);
  optimization_options.is_function_graph = true;
  optimization_options.graph = nullptr;
  optimization_options.device_set = nullptr;
  optimization_options.partition_graphs = &subgraphs;
  // Normally POST_PARTITIONING passes are run by distributed workers.
  // Distributed workers are currently not supported in this code path, so we
  // run the passes here.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (should_run_optimization_passes) {
    TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
        OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
  }

  for (const auto& pair : subgraphs) {
    const auto* optimized_subgraph = pair.second.get();
    DumpGraph(
        strings::StrCat("After all optimization passes (", pair.first, ")"),
        optimized_subgraph);
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("pflr_after_all_optimization_passes_",
                          reinterpret_cast<uintptr_t>(optimized_subgraph), "_",
                          pair.first),
          optimized_subgraph->ToGraphDefDebug());
    }
  }
  const uint64 optimization_end_time_usecs = Env::Default()->NowMicros();
  metrics::UpdateFunctionGraphOptimizationTime(optimization_end_time_usecs -
                                               optimization_start_time_usecs);

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  const auto& node_name_to_control_ret =
      optimized_graph_info.node_name_to_control_ret;
  // We must preserve control returns in each of the function components,
  // otherwise after function inlining we might prune side-effectful nodes.
  const auto control_ret =
      [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
    const auto it = node_name_to_control_ret.find(n->name());
    return it != node_name_to_control_ret.end()
               ? absl::make_optional<string>(it->second)
               : absl::nullopt;
  };

  auto data = std::make_unique<MultiDeviceFunctionData>(
      function_name, function_key, optimized_graph_info.num_return_nodes,
      std::move(optimized_graph_info.lib_def),
      std::move(optimized_graph_info.ret_types));

  int i = 0;
  // Generate a random function_name to avoid one function reuse the partition
  // function instantiated by another function.
  FunctionLibraryDefinition* data_lib_def = &data->lib_def_;
  FunctionNameGenerator name_generator(
      data_lib_def, absl::StrCat(function_name, "_", random::New64()));
  auto num_subgraphs = subgraphs.size();
  gtl::InlinedVector<Status, 4> instantiate_status(num_subgraphs);
  // One count per component; Wait() below blocks until all are instantiated.
  BlockingCounter counter(static_cast<int>(num_subgraphs));
  auto runner = [this, num_subgraphs](std::function<void()> fn) {
    // NOTE: Only use thread pool to instantiate sub-function when there are
    // more than 8 sub-functions. We want to avoid cost of switching thread when
    // there are only a few sub-functions.
    if (default_thread_pool_ != nullptr && num_subgraphs > 8) {
      default_thread_pool_->Schedule(fn);
    } else {
      fn();
    }
  };

  // Before instantiating component functions, determine synchronous execution.
  // Sync execution is enabled only when every component's async summary
  // permits it (i.e. none requires async execution).
  data->enable_sync_execution = false;
  if (options.allow_small_function_optimizations) {
    data->enable_sync_execution = true;
    for (const auto& pair : subgraphs) {
      ComponentFunctionData* comp_data = &data->glue_[pair.first];
      const Graph* subgraph = pair.second.get();
      comp_data->async_attributes =
          AsyncAttributes(subgraph, options.allow_control_flow_sync_execution);
      if (comp_data->async_attributes.summary() ==
          AsyncAttributes::kAsyncRequired) {
        data->enable_sync_execution = false;
      }
    }
  }

  // Instantiate each component function (subgraph).
  for (const auto& pair : subgraphs) {
    Status* status = &instantiate_status[i];
    string unique_name = name_generator.GetName();
    ComponentFunctionData* comp_data = &data->glue_[pair.first];
    // NOTE(review): references captured here (`pair`, `counter`, `data`, ...)
    // outlive the lambda because counter.Wait() below blocks until every
    // scheduled task has finished.
    runner([this, &pair, dev_set, comp_data, unique_name, data_lib_def,
            &control_ret, &options, status, &counter, &data] {
      const string& target = pair.first;

      const string& device_type =
          dev_set->FindDeviceByName(target)->device_type();
      Graph* subgraph = pair.second.get();

      // These device types (or an explicit option) keep int32 args/retvals
      // on the device rather than on the host.
      bool ints_on_device =
          (device_type == "TPU" || device_type == "XLA_CPU" ||
           device_type == "XLA_GPU" || options.int_args_and_retvals_on_device);
      status->Update(UpdateArgAndRetvalMetadata(
          subgraph, &comp_data->arg_indices, &comp_data->ret_indices,
          &comp_data->arg_alloc_attrs, &comp_data->ret_alloc_attrs,
          ints_on_device));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionDef shard;
      status->Update(
          GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      status->Update(data_lib_def->AddFunctionDef(shard));
      if (!status->ok()) {
        counter.DecrementCount();
        return;
      }
      FunctionLibraryRuntime::InstantiateOptions opts;
      opts.executor_type = options.executor_type;
      opts.target = target;
      opts.lib_def = data_lib_def;
      opts.create_kernels_eagerly = options.create_kernels_eagerly;
      opts.state_handle = options.state_handle;
      opts.allow_small_function_optimizations = data->enable_sync_execution;
      opts.allow_control_flow_sync_execution =
          options.allow_control_flow_sync_execution;
      AttrValue ints_on_device_attr;
      ints_on_device_attr.set_b(options.int_args_and_retvals_on_device);
      shard.mutable_attr()->insert(
          {FunctionLibraryDefinition::kIntsOnDeviceAttr, ints_on_device_attr});
      auto attrs = AttrSlice(&shard.attr());
      VLOG(1) << "Start instantiating component function " << unique_name
              << " on device " << target;
      VLOG(4) << DebugString(shard);

      // Heap-allocated so the async `done` callback can own and delete it.
      auto* component_handle = new FunctionLibraryRuntime::Handle;
      auto done = [this, status, unique_name, comp_data, component_handle,
                   &data, &counter](const Status& s) {
        status->Update(s);

        VLOG(1) << "Finished instantiating component function " << unique_name
                << " with handle " << *component_handle << " status: " << s;
        if (status->ok()) {
          {
            mutex_lock l(mu_);
            // A single cross-process component makes the whole function
            // cross-process.
            if (function_data_[*component_handle]->is_cross_process()) {
              data->is_cross_process_ = true;
            }
          }
          comp_data->handle = *component_handle;
        }
        delete component_handle;
        counter.DecrementCount();
      };

      FunctionLibraryRuntime* flr = GetFLR(opts.target);
      if (flr != nullptr) {
        // Initialize local function synchronously.
        Status s = flr->Instantiate(unique_name, attrs, opts, component_handle);
        done(s);
      } else {
        opts.ret_indices = comp_data->ret_indices;
        // Initialize remote function asynchronously.
        InstantiateRemote(unique_name, attrs, opts, component_handle, done);
      }
    });
    i += 1;
  }
  // Block until every component (local or remote) has finished instantiating,
  // then summarize all per-component statuses.
  counter.Wait();
  StatusGroup group;
  for (auto& status : instantiate_status) {
    group.Update(status);
  }
  TF_RETURN_IF_ERROR(group.as_summary_status());

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return OkStatus();
}
1199 | |
1200 | Status ProcessFunctionLibraryRuntime::GetOutputDevices( |
1201 | FunctionLibraryRuntime::Handle handle, |
1202 | std::vector<Device*>* output_devices) const { |
1203 | MultiDeviceFunctionData* data = IsMultiDevice(handle); |
1204 | if (data == nullptr) { |
1205 | return errors::InvalidArgument( |
1206 | "Failed for find multi-device function handle " , handle); |
1207 | } |
1208 | |
1209 | for (const auto& pair : data->glue_) { |
1210 | const ComponentFunctionData& comp_data = pair.second; |
1211 | DCHECK(comp_data.ret_alloc_attrs.size() == comp_data.ret_indices.size()); |
1212 | if (comp_data.ret_indices.empty()) { |
1213 | continue; |
1214 | } |
1215 | |
1216 | const string& target = pair.first; |
1217 | FunctionLibraryRuntime* target_flr = GetFLR(target); |
1218 | Device* target_device = nullptr; |
1219 | Device* host = nullptr; |
1220 | if (target_flr == nullptr) { |
1221 | if (!data->has_remote_outputs) { |
1222 | data->has_remote_outputs = true; |
1223 | } |
1224 | target_device = device_set()->FindDeviceByName(target); |
1225 | string remote_host; |
1226 | TF_RETURN_IF_ERROR( |
1227 | DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host)); |
1228 | host = device_set()->FindDeviceByName(remote_host); |
1229 | } else { |
1230 | target_device = target_flr->device(); |
1231 | } |
1232 | output_devices->resize(data->num_outputs_); |
1233 | for (int j = 0; j < comp_data.ret_indices.size(); ++j) { |
1234 | int ret_index = comp_data.ret_indices[j]; |
1235 | if (data->ret_types_[ret_index] == DT_RESOURCE) { |
1236 | (*output_devices)[ret_index] = target_device; |
1237 | } else { |
1238 | (*output_devices)[ret_index] = |
1239 | comp_data.ret_alloc_attrs[j].on_host() ? host : target_device; |
1240 | } |
1241 | } |
1242 | } |
1243 | |
1244 | return OkStatus(); |
1245 | } |
1246 | |
1247 | Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( |
1248 | const FunctionLibraryRuntime::Options& opts, |
1249 | FunctionLibraryRuntime::Handle handle, |
1250 | const MultiDeviceFunctionData** data) const { |
1251 | if (opts.create_rendezvous) { |
1252 | // FLR->Run() is the default entry point. It checks for cancellation, |
1253 | // creates rendezvous, etc. |
1254 | // Letting create_rendezvous through will do the wrong thing - each |
1255 | // component function will get a separate rendezvous created by its FLR. |
1256 | return errors::Internal( |
1257 | "Cannot call ProcessFunctionLibraryRuntime::Run with " |
1258 | "create_rendezvous=true. Please run the function " |
1259 | "using FunctionLibraryRuntime::Run" ); |
1260 | } |
1261 | |
1262 | *data = IsMultiDevice(handle); |
1263 | if (*data == nullptr) { |
1264 | return errors::NotFound("Multi-device function handle " , handle, |
1265 | "not found. Was the function instantiated?" ); |
1266 | } |
1267 | |
1268 | // Check whether we have the right rendezvous. |
1269 | if (opts.rendezvous && (*data)->is_cross_process_ && |
1270 | !opts.rendezvous->is_cross_process()) { |
1271 | return errors::InvalidArgument( |
1272 | "Running a cross process function " , (*data)->function_name_, |
1273 | " without an appropriate cross process Rendezvous." ); |
1274 | } |
1275 | |
1276 | return OkStatus(); |
1277 | } |
1278 | |
1279 | std::vector<string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs( |
1280 | const MultiDeviceFunctionData* data) const { |
1281 | std::vector<string> subgraph_keys; |
1282 | subgraph_keys.reserve(data->glue_.size()); |
1283 | for (const auto& pair : data->glue_) { |
1284 | subgraph_keys.push_back(pair.first); |
1285 | } |
1286 | auto send_first_ordering = [&](const string& a, const string& b) { |
1287 | auto a_summary = data->glue_.at(a).async_attributes.summary(); |
1288 | auto b_summary = data->glue_.at(b).async_attributes.summary(); |
1289 | if (a_summary == b_summary) { |
1290 | return false; |
1291 | } |
1292 | if (a_summary == AsyncAttributes::kSendOnly) { |
1293 | return true; |
1294 | } |
1295 | return false; |
1296 | }; |
1297 | std::sort(subgraph_keys.begin(), subgraph_keys.end(), send_first_ordering); |
1298 | return subgraph_keys; |
1299 | } |
1300 | |
1301 | Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( |
1302 | const FunctionLibraryRuntime::Options& opts, |
1303 | FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets, |
1304 | std::function<Status(const ComponentFunctionData& comp_data, |
1305 | InternalArgs* args)> |
1306 | get_component_args) const { |
1307 | const MultiDeviceFunctionData* data; |
1308 | Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data); |
1309 | if (!prepare_status.ok()) { |
1310 | return prepare_status; |
1311 | } |
1312 | |
1313 | FunctionLibraryRuntime::Options opts_copy = opts; |
1314 | |
1315 | // Sort the subgraphs topologically before execution to avoid deadlock: |
1316 | // |
1317 | // Because subgraphs will not execute in parallel here, dependencies between |
1318 | // subgraphs cannot be resolved automatically. In contrast, with multi- |
1319 | // threaded execution, we launch all subgraphs at once, asynchronously, and |
1320 | // allow any to block mid-execution while its dependencies are resolved. |
1321 | // |
1322 | // In this synchronous execution path, currently supported ops with inter- |
1323 | // subgraph dependencies are send and receive. As `_Send` and `_HostSend` |
1324 | // are non-blocking, we run subgraphs with those first, and those with |
1325 | // the blocking '_Recv' and '_HostRecv' ops will have their dependencies |
1326 | // resolved before execution. |
1327 | // |
1328 | // We assume that the partitioning has a valid deadlock-free ordering and the |
1329 | // safety of running synchronously has already been confirmed by this point. |
1330 | std::vector<string> subgraph_keys = GetOrderedSubgraphs(data); |
1331 | |
1332 | for (const string& target : subgraph_keys) { |
1333 | const ComponentFunctionData& comp_data = data->glue_.at(target); |
1334 | FunctionLibraryRuntime::Handle comp_handle = comp_data.handle; |
1335 | |
1336 | opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs; |
1337 | opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs; |
1338 | |
1339 | InternalArgs comp_args; |
1340 | Status args_status = get_component_args(comp_data, &comp_args); |
1341 | if (!args_status.ok()) { |
1342 | VLOG(2) << "Failed to get component function arguments: " << args_status; |
1343 | return args_status; |
1344 | } |
1345 | rets->resize(data->num_outputs_); |
1346 | |
1347 | VLOG(1) << "Running component function on device " << target << " from " |
1348 | << data->function_name_ << " with handle " << comp_handle; |
1349 | FunctionLibraryRuntime* flr = GetFLR(target); |
1350 | if (flr != nullptr) { |
1351 | opts_copy.remote_execution = false; |
1352 | // When target device has private thread pool, use the target device |
1353 | // runner |
1354 | thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool(); |
1355 | opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner(); |
1356 | VLOG(4) << " with " << opts_copy.DebugString(); |
1357 | |
1358 | std::vector<Tensor> comp_tensor_rets; |
1359 | Status run_status = |
1360 | flr->RunSync(opts_copy, comp_handle, GetLocalArgs(comp_args.args), |
1361 | &comp_tensor_rets); |
1362 | if (!run_status.ok()) { |
1363 | VLOG(2) << "Component function execution failed: " << run_status; |
1364 | const string function_and_msg = strings::StrCat( |
1365 | errors::FormatFunctionForError(data->function_name_), " " , |
1366 | run_status.error_message()); |
1367 | if (opts.rendezvous != nullptr) opts.rendezvous->StartAbort(run_status); |
1368 | return errors::CreateWithUpdatedMessage(run_status, function_and_msg); |
1369 | } else { |
1370 | VLOG(2) << "Component function execution succeeded." ; |
1371 | for (int i = 0; i < comp_tensor_rets.size(); ++i) { |
1372 | (*rets)[comp_data.ret_indices[i]] = comp_tensor_rets[i]; |
1373 | } |
1374 | } |
1375 | } else { |
1376 | // Fall back to DistributedFunctionLibraryRuntime for remote execution. |
1377 | opts_copy.remote_execution = true; |
1378 | VLOG(4) << " with " << opts_copy.DebugString(); |
1379 | |
1380 | std::vector<std::unique_ptr<CleanUpItem>> cleanup_items; |
1381 | Notification n; |
1382 | Status s; |
1383 | std::vector<FunctionRet> comp_rets; |
1384 | RunInternal(opts_copy, comp_handle, comp_args.args, &comp_rets, |
1385 | &cleanup_items, [&n, &s](const Status& status) { |
1386 | s.Update(status); |
1387 | n.Notify(); |
1388 | }); |
1389 | n.WaitForNotification(); |
1390 | return s; |
1391 | } |
1392 | } |
1393 | return OkStatus(); |
1394 | } |
1395 | |
// Launches every component function of `outer_handle` concurrently and
// invokes `done` once all of them have completed. The first component
// failure is reported through `done`, and the remaining components are
// cancelled via a (possibly locally created) CancellationManager.
void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle outer_handle, std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done,
    std::function<Status(const ComponentFunctionData& comp_data,
                         InternalArgs* args)>
        get_component_args) const {
  const MultiDeviceFunctionData* data;
  Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data);
  if (!prepare_status.ok()) {
    done(prepare_status);
    return;
  }

  // A locally created cancellation manager, used only when the caller does not
  // provide one in argument.
  std::shared_ptr<CancellationManager> local_cm;
  CancellationManager* cm = opts.cancellation_manager;
  if (cm == nullptr) {
    local_cm = std::make_shared<CancellationManager>();
    cm = local_cm.get();
  }

  // Take one extra reference per component; each component callback drops
  // one, and the final Unref at the bottom drops the initial one.
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (int i = 0; i < data->glue_.size(); ++i) {
    refcounted_done->Ref();
  }

  FunctionLibraryRuntime::Options opts_copy = opts;
  for (const auto& pair : data->glue_) {
    const string& target = pair.first;
    const ComponentFunctionData& comp_data = pair.second;
    FunctionLibraryRuntime::Handle comp_handle = pair.second.handle;

    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs;
    opts_copy.cancellation_manager = cm;

    InternalArgs comp_args;
    Status s = get_component_args(comp_data, &comp_args);
    if (!s.ok()) {
      VLOG(2) << "Failed to get component function arguments: " << s;
      refcounted_done->UpdateStatus(s);
      refcounted_done->Unref();
      // Cancel any components that were already launched.
      cm->StartCancel();
      continue;
    }
    std::vector<FunctionRet>* comp_rets = new std::vector<FunctionRet>;
    rets->resize(data->num_outputs_);

    // `local_cm` is captured by value so the locally created
    // CancellationManager stays alive until every component callback has run.
    auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done,
                                  cm, local_cm, data, comp_handle,
                                  target](const Status& status) {
      if (!status.ok()) {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle "
                << comp_handle << " failed: " << status;
        const string function_and_msg = strings::StrCat(
            errors::FormatFunctionForError(data->function_name_), " ",
            status.error_message());
        refcounted_done->UpdateStatus(
            errors::CreateWithUpdatedMessage(status, function_and_msg));
        // Cancel the execution of other component functions.
        cm->StartCancel();
      } else {
        VLOG(2) << "Component function execution on target " << target
                << " from " << data->function_name_ << " with handle "
                << comp_handle << " succeeded.";
        // Scatter this component's outputs into the full return vector.
        for (int i = 0; i < comp_rets->size(); ++i) {
          (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i];
        }
      }
      delete comp_rets;
      // refcounted_done is thread-safe
      refcounted_done->Unref();
    };

    FunctionLibraryRuntime* flr = GetFLR(target);
    if (flr != nullptr) {
      opts_copy.remote_execution = false;
      // When target device has private thread pool, use the target device
      // runner
      thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool();
      opts_copy.runner = (pool == nullptr) ? opts.runner : flr->runner();

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << comp_handle;
      VLOG(4) << " with " << opts_copy.DebugString();

      std::vector<Tensor>* comp_tensor_rets = new std::vector<Tensor>;
      flr->Run(
          opts_copy, comp_handle, GetLocalArgs(comp_args.args),
          comp_tensor_rets,
          TensorsToFunctionRetsDoneCallback(comp_rets, comp_tensor_rets,
                                            std::move(component_fn_callback)));
    } else {
      opts_copy.remote_execution = true;

      VLOG(1) << "Running component function on device " << target << " from "
              << data->function_name_ << " with handle " << comp_handle;
      VLOG(4) << " with " << opts_copy.DebugString();

      RunInternal(opts_copy, comp_handle, comp_args.args, comp_rets,
                  cleanup_items, std::move(component_fn_callback));
    }
  }
  // Drop the initial reference; `done` fires after all components complete.
  refcounted_done->Unref();
}
1505 | |
1506 | Status ProcessFunctionLibraryRuntime::Instantiate( |
1507 | const string& function_name, AttrSlice attrs, |
1508 | const FunctionLibraryRuntime::InstantiateOptions& options, |
1509 | FunctionLibraryRuntime::Handle* handle) { |
1510 | if (options.is_multi_device_function) { |
1511 | return InstantiateMultiDevice(function_name, attrs, options, handle); |
1512 | } |
1513 | |
1514 | *handle = kInvalidHandle; |
1515 | FunctionLibraryRuntime* flr = GetFLR(options.target); |
1516 | if (flr != nullptr) { |
1517 | return flr->Instantiate(function_name, attrs, options, handle); |
1518 | } |
1519 | |
1520 | Status status; |
1521 | Notification notification; |
1522 | InstantiateRemote(function_name, attrs, options, handle, |
1523 | [&status, ¬ification](const Status& s) { |
1524 | status = s; |
1525 | notification.Notify(); |
1526 | }); |
1527 | notification.WaitForNotification(); |
1528 | return status; |
1529 | } |
1530 | |
1531 | Status ProcessFunctionLibraryRuntime::IsCrossProcess( |
1532 | FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const { |
1533 | tf_shared_lock l(mu_); |
1534 | const auto& mdevice_it = mdevice_data_.find(handle); |
1535 | if (mdevice_it != mdevice_data_.end()) { |
1536 | *is_cross_process = mdevice_it->second->is_cross_process_; |
1537 | return OkStatus(); |
1538 | } |
1539 | const auto& it = function_data_.find(handle); |
1540 | if (it != function_data_.end()) { |
1541 | *is_cross_process = it->second->is_cross_process(); |
1542 | return OkStatus(); |
1543 | } |
1544 | return errors::InvalidArgument("Handle " , handle, " not found." ); |
1545 | } |
1546 | |
1547 | void ProcessFunctionLibraryRuntime::InstantiateRemote( |
1548 | const string& function_name, AttrSlice attrs, |
1549 | const FunctionLibraryRuntime::InstantiateOptions& options, |
1550 | FunctionLibraryRuntime::Handle* handle, |
1551 | FunctionLibraryRuntime::DoneCallback done) { |
1552 | if (parent_ == nullptr) { |
1553 | done(errors::Internal( |
1554 | "Currently don't support instantiating functions on device: " , |
1555 | options.target)); |
1556 | return; |
1557 | } |
1558 | auto target = options.target; |
1559 | VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target; |
1560 | string function_key = Canonicalize(function_name, attrs, options); |
1561 | FunctionData* f; |
1562 | { |
1563 | mutex_lock l(mu_); |
1564 | FunctionLibraryRuntime::Handle h = |
1565 | gtl::FindWithDefault(table_, function_key, kInvalidHandle); |
1566 | if (h == kInvalidHandle || function_data_.count(h) == 0) { |
1567 | h = AddHandleLocked(function_key, target, kInvalidHandle); |
1568 | } |
1569 | f = function_data_[h].get(); |
1570 | *handle = h; |
1571 | } |
1572 | f->DistributedInit( |
1573 | parent_, function_name, |
1574 | options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options, |
1575 | [this, function_name, target, handle, done](const Status& s) { |
1576 | VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name |
1577 | << " on: " << target << " with handle: " << *handle |
1578 | << " (this: " << this << ")" ; |
1579 | done(s); |
1580 | }); |
1581 | } |
1582 | |
1583 | Status ProcessFunctionLibraryRuntime::RemoveHandle( |
1584 | FunctionLibraryRuntime::Handle handle) { |
1585 | mutex_lock l(mu_); |
1586 | table_.erase(function_data_[handle]->function_key()); |
1587 | function_data_.erase(handle); |
1588 | return OkStatus(); |
1589 | } |
1590 | |
1591 | Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( |
1592 | FunctionLibraryRuntime::Handle handle) { |
1593 | std::unique_ptr<MultiDeviceFunctionData> mdata; |
1594 | { |
1595 | mutex_lock l(mu_); |
1596 | auto it = mdevice_data_.find(handle); |
1597 | --it->second->instantiation_counter_; |
1598 | if (it->second->instantiation_counter_ != 0) { |
1599 | return OkStatus(); |
1600 | } |
1601 | mdata = std::move(it->second); |
1602 | table_.erase(mdata->function_key_); |
1603 | mdevice_data_.erase(it); |
1604 | } |
1605 | |
1606 | // If we are here we are releasing the last instantiation of `handle`. |
1607 | // Release all component function handles. |
1608 | Status overall_status; |
1609 | for (const auto& it : mdata->glue_) { |
1610 | const string& device = it.first; |
1611 | FunctionLibraryRuntime::Handle flr_handle = it.second.handle; |
1612 | FunctionLibraryRuntime* flr = GetFLR(device); |
1613 | if (flr == nullptr) { |
1614 | // TODO(nareshmodi): Implement DeregisterGraph call to remote device if |
1615 | // parent is not null. |
1616 | if (parent_ != nullptr) { |
1617 | return errors::Unimplemented( |
1618 | "Releasing a multi-device component handle on a remote device is " |
1619 | "not yet implemented." ); |
1620 | } |
1621 | return errors::InvalidArgument( |
1622 | "Failed to find FunctionLibraryRuntime for device " , device, |
1623 | " when releasing multi-device function handle " , handle); |
1624 | } |
1625 | Status status = flr->ReleaseHandle(flr_handle); |
1626 | if (!status.ok()) { |
1627 | overall_status = status; |
1628 | } |
1629 | } |
1630 | |
1631 | return overall_status; |
1632 | } |
1633 | |
1634 | Status ProcessFunctionLibraryRuntime::ReleaseHandle( |
1635 | FunctionLibraryRuntime::Handle handle) { |
1636 | // Return directly if all function handles has already been released. |
1637 | if (flr_map_ == nullptr) return OkStatus(); |
1638 | |
1639 | if (IsMultiDevice(handle)) { |
1640 | return ReleaseMultiDeviceHandle(handle); |
1641 | } |
1642 | |
1643 | FunctionLibraryRuntime* flr = nullptr; |
1644 | string target_device; |
1645 | { |
1646 | mutex_lock l(mu_); |
1647 | CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle; |
1648 | target_device = function_data_[handle]->target_device(); |
1649 | } |
1650 | flr = GetFLR(target_device); |
1651 | if (flr != nullptr) { |
1652 | return flr->ReleaseHandle(handle); |
1653 | } |
1654 | return errors::InvalidArgument("Handle not found: " , handle); |
1655 | } |
1656 | |
1657 | void ProcessFunctionLibraryRuntime::CleanupCreatedRendezvous( |
1658 | const Rendezvous* created_rendezvous, const int64_t step_id) const { |
1659 | if (created_rendezvous) { |
1660 | DCHECK(rendezvous_factory_); |
1661 | created_rendezvous->Unref(); |
1662 | Status s = rendezvous_factory_.CleanUp(step_id); |
1663 | if (!s.ok()) { |
1664 | LOG(ERROR) << s; |
1665 | } |
1666 | } |
1667 | } |
1668 | |
// Wraps `done` so that, when the returned callback fires, any rendezvous we
// created for this step is released and the accumulated per-step cleanup
// `items` are cleaned up before the original `done` receives the combined
// status. Takes ownership of `items` (heap-allocated by the caller).
FunctionLibraryRuntime::DoneCallback
ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done, const int64_t step_id,
    const Rendezvous* created_rendezvous) const {
  return [this, items, done = std::move(done), step_id,
          created_rendezvous](const Status& status) {
    // Release the rendezvous created on the caller's behalf (no-op if the
    // caller provided its own).
    this->CleanupCreatedRendezvous(created_rendezvous, step_id);
    // Heap-allocate the status so the asynchronous CleanUp callback can merge
    // the cleanup status into it before invoking `done`.
    auto* local_status = new Status(status);
    CleanUp(items, [local_status, done](const Status& cleanup_status) {
      local_status->Update(cleanup_status);
      done(*local_status);
      delete local_status;
    });
    // CleanUp iterates `items` synchronously (its callbacks do not retain the
    // vector), so deleting it here is safe even if cleanups are still pending.
    delete items;
  };
}
1686 | |
1687 | Status ProcessFunctionLibraryRuntime::CreateRendezvous( |
1688 | FunctionLibraryRuntime::Options& opts, |
1689 | Rendezvous** created_rendezvous) const { |
1690 | DCHECK(opts.rendezvous == nullptr); |
1691 | if (!rendezvous_factory_) { |
1692 | return errors::FailedPrecondition( |
1693 | "The caller does not provide a rendezvous and " |
1694 | "ProcessFunctionLibraryRuntime was created without a rendezvous " |
1695 | "factory." ); |
1696 | } |
1697 | Status s = rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous); |
1698 | if (s.ok()) { |
1699 | opts.rendezvous = *created_rendezvous; |
1700 | opts.create_rendezvous = false; |
1701 | } |
1702 | return s; |
1703 | } |
1704 | |
1705 | Status ProcessFunctionLibraryRuntime::GetComponentArgs( |
1706 | const gtl::ArraySlice<Tensor> args, |
1707 | const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data, |
1708 | ProcessFunctionLibraryRuntime::InternalArgs* comp_args) { |
1709 | // "Index"s of _Arg nodes are unique when all arguments are local Tensors. |
1710 | for (const auto& it : comp_data.arg_indices) { |
1711 | if (it.index >= args.size()) { |
1712 | return errors::InvalidArgument("index " , it.index, |
1713 | " is out of range [0, " , args.size(), ")" ); |
1714 | } |
1715 | if (it.sub_index >= 0) { |
1716 | const Tensor& t = args[it.index]; |
1717 | if (t.dtype() != DT_RESOURCE) { |
1718 | return errors::InvalidArgument("Got unexpected sub_index " , |
1719 | it.sub_index, " for argument " , |
1720 | it.index); |
1721 | } |
1722 | const auto& handles = t.flat<ResourceHandle>(); |
1723 | if (it.sub_index >= handles.size()) { |
1724 | return errors::InvalidArgument("Sub_index " , it.sub_index, |
1725 | "is out of range [0," , handles.size(), |
1726 | ") for argument " , it.index); |
1727 | } |
1728 | comp_args->args.push_back(Tensor(handles(it.sub_index))); |
1729 | } else { |
1730 | comp_args->args.push_back(args[it.index]); |
1731 | } |
1732 | } |
1733 | return OkStatus(); |
1734 | } |
1735 | |
1736 | #if !defined(IS_MOBILE_PLATFORM) |
1737 | Status ProcessFunctionLibraryRuntime::GetComponentArgs( |
1738 | const FunctionArgsInterface& args, |
1739 | const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data, |
1740 | ProcessFunctionLibraryRuntime::InternalArgs* comp_args) { |
1741 | for (int i = 0; i < comp_data.arg_indices.size(); ++i) { |
1742 | const FunctionArgIndex index = comp_data.arg_indices.at(i); |
1743 | Tensor tensor; |
1744 | if (args.GetLocalArg(index, &tensor).ok()) { |
1745 | comp_args->args.push_back(std::move(tensor)); |
1746 | } else { |
1747 | eager::RemoteTensorHandle remote_handle; |
1748 | TF_RETURN_IF_ERROR(args.GetRemoteArg(index, &remote_handle)); |
1749 | comp_args->remote_args.emplace_back( |
1750 | std::make_unique<eager::RemoteTensorHandle>( |
1751 | std::move(remote_handle))); |
1752 | comp_args->args.push_back(comp_args->remote_args.back().get()); |
1753 | } |
1754 | } |
1755 | return OkStatus(); |
1756 | } |
1757 | #endif // IS_MOBILE_PLATFORM |
1758 | |
// Runs `handle` asynchronously with plain Tensor arguments. Creates a
// rendezvous if the caller did not provide one, then layers two wrappers
// around `done`: (1) per-step cleanup + rendezvous release, and (2)
// conversion of the internal FunctionRet results back into Tensors in
// `rets`. Dispatches to the multi-device path or to RunInternal.
void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    // CreateRendezvous installs the new rendezvous into new_opts.
    Status s = CreateRendezvous(new_opts, &created_rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
  }

  // `cleanup_items` is owned by the wrapped callback returned below.
  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
                                    new_opts.step_id, created_rendezvous);
  // Heap-allocate the intermediate results; the second wrapper converts them
  // to Tensors and frees them before chaining to the previous `done`.
  std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
  done = [rets, function_rets, done = std::move(done)](const Status& s) {
    Status status = s;
    if (status.ok()) {
      status.Update(FunctionRetsToTensors(function_rets, rets));
    }
    delete function_rets;
    done(status);
  };
  bool multi_device = HasMultiDeviceHandle(handle);
  if (multi_device) {
    // Component functions pick their arguments out of `args` by index.
    auto get_component_args = [&args](const ComponentFunctionData& comp_data,
                                      InternalArgs* comp_args) -> Status {
      return GetComponentArgs(args, comp_data, comp_args);
    };
    return RunMultiDeviceAsync(new_opts, handle, function_rets, cleanup_items,
                               std::move(done), std::move(get_component_args));
  }
  // Single-device / remote path: wrap the Tensors as FunctionArgs.
  std::vector<FunctionArg> local_args;
  for (const auto& tensor : args) {
    local_args.push_back(tensor);
  }
  RunInternal(new_opts, handle, local_args, function_rets, cleanup_items,
              std::move(done));
}
1802 | |
1803 | // This method handles the simple remote call case (not multi-device). |
void ProcessFunctionLibraryRuntime::RunInternal(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<FunctionArg> args,
    std::vector<FunctionRet>* rets,
    std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    // Resolve the handle to its target device and local handle under the
    // shared lock; the rest of the function runs without it.
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: " , handle, " not found." ));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution." ));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    // The target device has a runtime in this process: ship the args to it
    // via the rendezvous, run there, then receive the results back.
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    // Incarnations are used to detect device restarts across send/recv.
    int64_t src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    std::vector<Tensor> local_args = GetLocalArgs(args);

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_" , src_incarnation,
                    local_args, device_context, opts.args_alloc_attrs,
                    rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    // Owned by the completion callback below; freed on both paths.
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    flr->Run(opts, handle, local_args, remote_rets,
             [source_device, target_device, target_incarnation, rendezvous,
              device_context, rets_alloc_attrs, remote_rets, rets,
              done = std::move(done)](const Status& status) mutable {
               if (!status.ok()) {
                 delete remote_rets;
                 done(status);
                 return;
               }
               int64_t num_returns = remote_rets->size();
               delete remote_rets;
               // Now receive the return values from the target.
               std::vector<Tensor>* recv_tensors = new std::vector<Tensor>;
               ReceiveTensorsAsync(target_device, source_device, "ret_" ,
                                   target_incarnation, num_returns,
                                   device_context, rets_alloc_attrs, rendezvous,
                                   recv_tensors,
                                   TensorsToFunctionRetsDoneCallback(
                                       rets, recv_tensors, std::move(done)));
             });
    return;
  }
  if (parent_ != nullptr) {
    // No local runtime for the target: delegate to the distributed runtime
    // and register a cleanup item so remote per-step state is released later.
    auto cleanup_item = std::make_unique<CleanUpItem>();
    cleanup_item->device = target_device;
    cleanup_item->step_id = opts.step_id;
    cleanup_item->local_handle = local_handle;
    cleanup_items->emplace_back(std::move(cleanup_item));
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device" ));
}
1897 | |
1898 | void ProcessFunctionLibraryRuntime::Run( |
1899 | const FunctionLibraryRuntime::Options& opts, |
1900 | FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame, |
1901 | FunctionLibraryRuntime::DoneCallback done) const { |
1902 | std::vector<Tensor> args; |
1903 | args.reserve(frame->num_args()); |
1904 | for (size_t i = 0; i < frame->num_args(); ++i) { |
1905 | const Tensor* arg; |
1906 | Status s = frame->GetArg(i, &arg); |
1907 | args.emplace_back(*arg); |
1908 | if (!s.ok()) { |
1909 | done(s); |
1910 | } |
1911 | } |
1912 | std::vector<Tensor>* rets = new std::vector<Tensor>; |
1913 | rets->reserve(frame->num_retvals()); |
1914 | |
1915 | Run(opts, handle, args, rets, |
1916 | |
1917 | [frame, rets, done = std::move(done)](const Status& status) { |
1918 | std::unique_ptr<std::vector<Tensor>> rets_releaser(rets); |
1919 | |
1920 | if (!status.ok()) { |
1921 | done(status); |
1922 | return; |
1923 | } |
1924 | |
1925 | if (rets->size() != frame->num_retvals()) { |
1926 | done(errors::Internal( |
1927 | "Number of return values from function (" , rets->size(), |
1928 | ") did not match expected number of return values (" , |
1929 | frame->num_retvals(), ")." )); |
1930 | return; |
1931 | } |
1932 | |
1933 | for (size_t i = 0; i < frame->num_retvals(); ++i) { |
1934 | Status s = frame->SetRetval(i, (*rets)[i]); |
1935 | if (!s.ok()) { |
1936 | done(s); |
1937 | return; |
1938 | } |
1939 | } |
1940 | done(OkStatus()); |
1941 | }); |
1942 | } |
1943 | |
1944 | Status ProcessFunctionLibraryRuntime::RunSync( |
1945 | const FunctionLibraryRuntime::Options& orig_opts, |
1946 | FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, |
1947 | std::vector<Tensor>* rets) const { |
1948 | MultiDeviceFunctionData* multi_device_data = IsMultiDevice(handle); |
1949 | if (multi_device_data && multi_device_data->enable_sync_execution) { |
1950 | metrics::IncrementTestCounter("pflr_runsync" , "sync" ); |
1951 | FunctionLibraryRuntime::Options new_opts = orig_opts; |
1952 | Rendezvous* created_rendezvous = nullptr; |
1953 | if (!new_opts.rendezvous) { |
1954 | TF_RETURN_IF_ERROR(CreateRendezvous(new_opts, &created_rendezvous)); |
1955 | } |
1956 | |
1957 | std::vector<FunctionRet> function_rets; |
1958 | auto get_component_args = [&args](const ComponentFunctionData& comp_data, |
1959 | InternalArgs* comp_args) { |
1960 | return GetComponentArgs(args, comp_data, comp_args); |
1961 | }; |
1962 | |
1963 | Status status = RunMultiDeviceSync(new_opts, handle, &function_rets, |
1964 | std::move(get_component_args)); |
1965 | CleanupCreatedRendezvous(created_rendezvous, new_opts.step_id); |
1966 | status.Update(FunctionRetsToTensors(&function_rets, rets)); |
1967 | return status; |
1968 | } else { |
1969 | // TODO(b/207484417): Either handle or avoid/delete this fallback path. |
1970 | metrics::IncrementTestCounter("pflr_runsync" , "async" ); |
1971 | Notification n; |
1972 | Status s; |
1973 | Run(orig_opts, handle, args, rets, [&n, &s](const Status& status) { |
1974 | s.Update(status); |
1975 | n.Notify(); |
1976 | }); |
1977 | n.WaitForNotification(); |
1978 | return s; |
1979 | } |
1980 | } |
1981 | |
1982 | Status ProcessFunctionLibraryRuntime::RunSync( |
1983 | const FunctionLibraryRuntime::Options& opts, |
1984 | FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const { |
1985 | // TODO(b/207485199): Implement this as synchronous code. |
1986 | Notification n; |
1987 | Status s; |
1988 | Run(opts, handle, frame, [&n, &s](const Status& status) { |
1989 | s.Update(status); |
1990 | n.Notify(); |
1991 | }); |
1992 | n.WaitForNotification(); |
1993 | return s; |
1994 | } |
1995 | |
1996 | void ProcessFunctionLibraryRuntime::Run( |
1997 | const FunctionLibraryRuntime::Options& opts, |
1998 | FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args, |
1999 | std::vector<FunctionRet>* rets, |
2000 | FunctionLibraryRuntime::DoneCallback done) const { |
2001 | bool has_remote_outputs = false; |
2002 | const MultiDeviceFunctionData* data = IsMultiDevice(handle); |
2003 | if (data != nullptr) { |
2004 | has_remote_outputs = data->has_remote_outputs; |
2005 | } |
2006 | if (!args.HasRemoteOrPackedInputs() && !has_remote_outputs) { |
2007 | const std::vector<Tensor> local_inputs = args.GetLocalTensors(); |
2008 | std::vector<Tensor>* tensor_rets = new std::vector<Tensor>; |
2009 | return Run( |
2010 | opts, handle, local_inputs, tensor_rets, |
2011 | TensorsToFunctionRetsDoneCallback(rets, tensor_rets, std::move(done))); |
2012 | } |
2013 | |
2014 | FunctionLibraryRuntime::Options new_opts = opts; |
2015 | Rendezvous* created_rendezvous = nullptr; |
2016 | if (!opts.rendezvous) { |
2017 | Status s = CreateRendezvous(new_opts, &created_rendezvous); |
2018 | if (!s.ok()) { |
2019 | done(s); |
2020 | return; |
2021 | } |
2022 | } |
2023 | |
2024 | #if defined(IS_MOBILE_PLATFORM) |
2025 | done(errors::Unimplemented( |
2026 | "Remote inputs are not available on mobile devices." )); |
2027 | return; |
2028 | #else // !IS_MOBILE_PLATFORM |
2029 | auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>; |
2030 | done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id, |
2031 | created_rendezvous); |
2032 | |
2033 | auto get_component_args = [&args](const ComponentFunctionData& comp_data, |
2034 | InternalArgs* comp_args) -> Status { |
2035 | return GetComponentArgs(args, comp_data, comp_args); |
2036 | }; |
2037 | return RunMultiDeviceAsync(new_opts, handle, rets, cleanup_items, |
2038 | std::move(done), std::move(get_component_args)); |
2039 | #endif // !IS_MOBILE_PLATFORM |
2040 | } |
2041 | |
// Cleans up remote per-step state for every item in `items`, invoking `done`
// once with the merged status after all cleanups complete. Uses a
// ref-counted wrapper so the asynchronous parent_->CleanUp callbacks and
// this function can each drop their reference independently.
void ProcessFunctionLibraryRuntime::CleanUp(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done) const {
  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (auto& item : *items) {
    // One reference per item; released on each item's completion path.
    refcounted_done->Ref();
    auto* flr = GetFLR(item->device);
    if (flr != nullptr) {
      // TODO(fishx): cleanup state for local execution.
      refcounted_done->UpdateStatus(
          errors::Internal("Cleanup items shouldn't contain local item." ));
      refcounted_done->Unref();
    } else if (parent_ != nullptr) {
      parent_->CleanUp(item->step_id, item->local_handle,
                       [refcounted_done](const Status& status) {
                         if (!status.ok()) {
                           refcounted_done->UpdateStatus(status);
                         }
                         // refcounted_done is thread-safe
                         refcounted_done->Unref();
                       });
    } else {
      refcounted_done->UpdateStatus(
          errors::Internal("Could not find device in cleanup." ));
      refcounted_done->Unref();
    }
  }
  // Drop the initial reference; `done` fires when the last item finishes.
  refcounted_done->Unref();
}
2071 | |
2072 | Status ProcessFunctionLibraryRuntime::Clone( |
2073 | Env* env, int graph_def_version, const OptimizerOptions& optimizer_options, |
2074 | std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, |
2075 | std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, |
2076 | bool skip_flib_def) const { |
2077 | if (skip_flib_def) { |
2078 | *out_lib_def = std::make_unique<FunctionLibraryDefinition>( |
2079 | lib_def_->default_registry(), FunctionDefLibrary{}); |
2080 | } else { |
2081 | *out_lib_def = std::make_unique<FunctionLibraryDefinition>(*lib_def_); |
2082 | } |
2083 | *out_pflr = std::make_unique<ProcessFunctionLibraryRuntime>( |
2084 | device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version, |
2085 | out_lib_def->get(), optimizer_options, default_thread_pool_, parent_, |
2086 | session_metadata_, rendezvous_factory_); |
2087 | { |
2088 | tf_shared_lock l(mu_); |
2089 | for (auto* d : composite_devices_) (*out_pflr)->AddCompositeDevice(d); |
2090 | } |
2091 | return OkStatus(); |
2092 | } |
2093 | |
2094 | } // namespace tensorflow |
2095 | |