/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/captured_function.h"

#include <utility>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
namespace data {
namespace {

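// Name of the tf.data experiment that enables small function optimizations in
// the function runtime (see `CapturedFunction::Instantiate()`).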
constexpr char kAllowSmallFunctionOptimizations[] =
    "allow_small_function_optimizations";

// Simplistic implementation of the `StepStatsCollectorInterface` that only
// cares about collecting the CPU time needed to execute a captured function.
class SimpleStepStatsCollector : public StepStatsCollectorInterface {
 public:
  void IncrementProcessingTime(int64_t delta) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override {
    return new SimpleNodeExecStats(this);
  }

  string ReportAllocsOnResourceExhausted(const string& err) override {
    return "";
  }

  int64_t processing_time() {
    tf_shared_lock l(mu_);
    return processing_time_;
  }

 private:
  class SimpleNodeExecStats : public NodeExecStatsInterface {
   public:
    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
        : step_stats_collector_(step_stats_collector) {}

    void Done(const string& device) override {
      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
                                                     start_time_ns_);
      delete this;
    }

    void RecordExecutorStarted() override {
      start_time_ns_ = absl::GetCurrentTimeNanos();
    }

    void RecordComputeStarted() override {}

    void RecordComputeEnded() override {}

    void RecordExecutorEnded() override {
      end_time_ns_ = absl::GetCurrentTimeNanos();
    }

    bool TrackAllocations() const override { return false; }

    void SetMemory(OpKernelContext* ctx) override {}

    void SetOutput(int slot, const Tensor* tensor) override {}

    void SetScheduled(int64_t nanos) override {}

   private:
    int64_t start_time_ns_ = 0;
    int64_t end_time_ns_ = 0;
    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
  };

  mutex mu_;
  int64_t processing_time_ TF_GUARDED_BY(mu_) = 0;
};

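// Returns the `index`-th captured input of `func` in `*out`, or an
// `OutOfRange` error if `index` exceeds the number of captured inputs.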
Status GetCapturedInput(const CapturedFunction* const func, int index,
                        const Tensor** out) {
  if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) {
    return errors::OutOfRange(
        "Out of range access to captured inputs for function ",
        func->func().name(), ". Index: ", index,
        ". Num captured inputs: ", func->captured_inputs().size());
  }
  *out = &func->captured_inputs()[index];
  return OkStatus();
}

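// "Short circuit" evaluation of the function: instead of running the function
// body, forwards the arguments and captured inputs selected by `info.indices`
// directly into `rets`. This overload copies the borrowed `args`.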
Status RunShortCircuit(const ShortCircuitInfo& info,
                       const std::vector<Tensor>& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      rets->push_back(args[info.indices[i]]);
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return OkStatus();
}

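// As above, but takes ownership of `args`, so an argument whose last use is at
// a given output position (per `info.can_move`) is moved rather than copied.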
Status RunShortCircuit(const ShortCircuitInfo& info, std::vector<Tensor>&& args,
                       const CapturedFunction* const func,
                       std::vector<Tensor>* rets) {
  VLOG(3) << "Running function " << func->func().name() << " short circuit";
  const int num_args = args.size();
  rets->reserve(info.indices.size());
  for (size_t i = 0; i < info.indices.size(); ++i) {
    if (info.indices[i] < num_args) {
      if (info.can_move[i]) {
        rets->push_back(std::move(args[info.indices[i]]));
      } else {
        rets->push_back(args[info.indices[i]]);
      }
    } else {
      const Tensor* captured_input;
      TF_RETURN_IF_ERROR(
          GetCapturedInput(func, info.indices[i] - num_args, &captured_input));
      rets->push_back(*captured_input);
    }
  }
  return OkStatus();
}

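// Determines whether `func` simply forwards (a subset of) its arguments to its
// outputs. If so, records in `info->indices` which argument feeds each return
// value and in `info->can_move` whether that argument's tensor may be moved.
// If the function is stateful or does not forward its arguments,
// `info->indices` is left empty so that the function is executed normally.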
Status CreateShortCircuitInfo(OpKernelConstruction* ctx,
                              const NameAttrList& func,
                              ShortCircuitInfo* info) {
  auto& indices = info->indices;

  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    return OkStatus();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices.resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &(indices[i])));
    } else {
      indices.clear();
      break;
    }
  }

  // Compute the `can_move` vector.
  if (!indices.empty()) {
    auto& can_move = info->can_move;
    std::map<int, int> last_use;
    for (size_t i = 0; i < indices.size(); ++i) {
      last_use[indices[i]] = i;
    }
    can_move.resize(indices.size());
    for (int i = 0, end = indices.size(); i < end; ++i) {
      can_move[i] = last_use[indices[i]] == i;
    }
  }

  return OkStatus();
}

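// Creates a function library definition that contains `func_name` and all of
// the functions it transitively references in `lib_def`.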
Status CreateFunctionLibraryDefinition(
    const FunctionLibraryDefinition* lib_def, const string& func_name,
    std::unique_ptr<FunctionLibraryDefinition>* result) {
  DCHECK(lib_def != nullptr);
  const FunctionDef* fdef = lib_def->Find(func_name);
  if (TF_PREDICT_FALSE(fdef == nullptr)) {
    return errors::FailedPrecondition(strings::StrCat(
        "Could not find required function definition ", func_name));
  }
  *result = std::make_unique<FunctionLibraryDefinition>(
      lib_def->ReachableDefinitions(*fdef));
  return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}

Status LookupFunction(const FunctionLibraryDefinition& lib_def,
                      const string& name, const FunctionDef** fdef) {
  *fdef = lib_def.Find(name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function ", name,
        " in function library: ", lib_def.ToProto().DebugString());
  }
  return OkStatus();
}

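// Base class for the call frames used to pass arguments into, and collect
// return values from, an invocation of a captured function.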
class CallFrameBase : public CallFrameInterface {
 public:
  explicit CallFrameBase(DataTypeSlice ret_types)
      : ret_types_(ret_types), retvals_(ret_types.size()) {}

  // Caller methods.
  Status ConsumeRetvals(std::vector<Tensor>* retvals) {
    retvals->reserve(retvals_.size());
    int i = 0;
    for (auto&& val : retvals_) {
      if (!val) {
        return errors::Internal("No return value for index ", i, ".");
      }
      retvals->emplace_back(std::move(val.value()));
      ++i;
    }
    return OkStatus();
  }

  size_t num_retvals() const override { return retvals_.size(); }

  // Callee methods.
  Status SetRetval(int index, const Tensor& val) override {
    const int retvals_size = retvals_.size();
    if (index < retvals_size && val.dtype() == ret_types_[index] &&
        !retvals_[index]) {
      retvals_[index] = val;
      return OkStatus();
    } else if (index >= retvals_size) {
      return errors::InvalidArgument("Return value ", index,
                                     " is out of range.");
    } else if (val.dtype() != ret_types_[index]) {
      return errors::InvalidArgument("Expected type ",
                                     DataTypeString(ret_types_[index]),
                                     " for return value ", index, " but got ",
                                     DataTypeString(val.dtype()), ".");
    } else {
      return errors::Internal("Attempted to set return value ", index,
                              " more than once.");
    }
  }

 private:
  DataTypeSlice ret_types_;
  std::vector<gtl::optional<Tensor>> retvals_;
  TF_DISALLOW_COPY_AND_ASSIGN(CallFrameBase);
};

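// Call frame that owns its argument tensors, which allows the function runtime
// to consume (move from) them.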
class OwnedArgsCallFrame : public CallFrameBase {
 public:
  OwnedArgsCallFrame(std::vector<Tensor>&& args,
                     const std::vector<Tensor>* captured_inputs,
                     DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(std::move(args)),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return OkStatus();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_.size()];
      return OkStatus();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

  // Since we own the argument tensors in `args_`, we can implement
  // `ConsumeArg()` for those arguments.
  void ConsumeArg(int index, Tensor* val) override {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, args_.size());
    *val = std::move(args_[index]);
  }
  bool CanConsumeArg(int index) const override {
    return index >= 0 && index < static_cast<int>(args_.size());
  }

 private:
  std::vector<Tensor> args_;
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

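// Call frame that borrows its argument tensors from the caller, so the
// function runtime may only read them.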
class BorrowedArgsCallFrame : public CallFrameBase {
 public:
  BorrowedArgsCallFrame(const std::vector<Tensor>& args,
                        const std::vector<Tensor>* captured_inputs,
                        DataTypeSlice ret_types)
      : CallFrameBase(ret_types),
        args_(args),
        captured_inputs_(captured_inputs) {}

  size_t num_args() const override {
    return args_.size() + captured_inputs_->size();
  }

  // Callee methods.
  Status GetArg(int index, const Tensor** val) override {
    const int args_size = args_.size();
    const int captured_inputs_size = captured_inputs_->size();
    if (index < args_size) {
      *val = &args_[index];
      return OkStatus();
    } else if (index < args_size + captured_inputs_size) {
      *val = &(*captured_inputs_)[index - args_size];
      return OkStatus();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
    }
  }

 private:
  const std::vector<Tensor>& args_;  // Not owned.
  const std::vector<Tensor>* const captured_inputs_;  // Not owned.
};

}  // namespace

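// Runs `inst_captured_func` on `input_element` and creates an iterator for the
// dataset variant it returns. This overload does not associate the computation
// with a `model::Node`.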
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator) {
  return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
                                      inst_captured_func, prefix, out_iterator,
                                      /*node=*/nullptr);
}

Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
      ctx, input_element, &return_values, node));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  std::string iterator_prefix =
      strings::StrCat(prefix, "[", thread_index, "]");

  return returned_dataset->MakeIterator(MakeNestedIteratorContext(ctx), parent,
                                        iterator_prefix, out_iterator);
}

IteratorContext MakeNestedIteratorContext(IteratorContext* ctx) {
  // Strip out any split providers so that they don't apply to sub-iterators.
  if (ctx->split_providers().empty()) {
    return *ctx;
  }
  IteratorContext::Params params(ctx);
  params.split_providers.clear();
  return IteratorContext(std::move(params));
}

/* static */
Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, const string& func_name, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  NameAttrList func;
  TF_RETURN_IF_ERROR(ctx->GetAttr(func_name, &func));
  return Create(ctx, std::move(func), params, out_metadata);
}

Status FunctionMetadata::Create(
    OpKernelConstruction* ctx, NameAttrList&& func, Params params,
    std::shared_ptr<FunctionMetadata>* out_metadata) {
  out_metadata->reset(new FunctionMetadata(std::move(func), params));
  TF_RETURN_IF_ERROR(CreateFunctionLibraryDefinition(
      ctx->function_library()->GetFunctionLibraryDefinition(),
      (*out_metadata)->func_.name(), &(*out_metadata)->lib_def_));
  TF_RETURN_IF_ERROR(CreateShortCircuitInfo(
      ctx, (*out_metadata)->func_, &(*out_metadata)->short_circuit_info_));
  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(LookupFunction(*(*out_metadata)->lib_def(),
                                    (*out_metadata)->func().name(), &fdef));

  auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  if (attr != fdef->attr().end() && attr->second.b()) {
    VLOG(1) << "Disabling multi-device execution for a function that uses the "
            << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute.";
    (*out_metadata)->use_multi_device_function_ = false;
    return OkStatus();
  }
  auto validate_arg = [](const OpDef::ArgDef& arg) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      VLOG(1) << "Disabling multi-device execution for a function with "
              << "a vector argument " << arg.name() << ".";
      return false;
    }
    return true;
  };
  for (const auto& arg : fdef->signature().input_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return OkStatus();
    }
  }
  for (const auto& arg : fdef->signature().output_arg()) {
    if (!validate_arg(arg)) {
      (*out_metadata)->use_multi_device_function_ = false;
      return OkStatus();
    }
  }
  return OkStatus();
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    const string& argument_name,
    std::unique_ptr<CapturedFunction>* out_function) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
  std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
  return Create(ctx, std::move(metadata), std::move(captured_inputs),
                out_function);
}

/* static */
Status CapturedFunction::Create(
    OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor>&& captured_inputs,
    std::unique_ptr<CapturedFunction>* out_function) {
  *out_function = absl::WrapUnique(
      new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
  return OkStatus();
}

Status CapturedFunction::AddToGraph(
    SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b,
    std::vector<Node*>* other_arguments,
    DataTypeVector* other_arguments_types) const {
  other_arguments->reserve(captured_inputs_.size());
  other_arguments_types->reserve(captured_inputs_.size());
  for (const Tensor& t : captured_inputs_) {
    Node* node;
    if (!ctx->is_graph_rewrite()) {
      TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
    } else {
      TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
      DCHECK_NE(ctx->input_list(), nullptr);
      ctx->input_list()->emplace_back(node->name(), t);
    }
    other_arguments->emplace_back(node);
    other_arguments_types->emplace_back(t.dtype());
  }
  TF_RETURN_IF_ERROR(
      b->AddFunction(ctx, metadata_->func().name(), *metadata_->lib_def()));
  return OkStatus();
}

Status CapturedFunction::Instantiate(
    IteratorContext* ctx, std::unique_ptr<InstantiatedCapturedFunction>*
                              instantiated_captured_function) {
  return CapturedFunction::Instantiate(InstantiateCapturedFunctionParams(ctx),
                                       instantiated_captured_function);
}

// TODO(b/190831948): Check whether the function creates a resource and if so,
// produce a warning.
Status CapturedFunction::Instantiate(
    InstantiateCapturedFunctionParams params,
    std::unique_ptr<InstantiatedCapturedFunction>*
        instantiated_captured_function) {
  // The context's runtime will be used for all subsequent calls.
  FunctionLibraryRuntime* lib = params.flr;
  FunctionLibraryRuntime::InstantiateOptions inst_opts;
  inst_opts.lib_def = metadata_->lib_def();
  inst_opts.create_kernels_eagerly = true;
  inst_opts.default_device_to_target = metadata_->use_default_device();
  inst_opts.config_proto =
      lib->config_proto() ? *lib->config_proto() : ConfigProto();
  if (GetExperiments().contains(kAllowSmallFunctionOptimizations)) {
    inst_opts.allow_small_function_optimizations = true;
  } else {
    if (!metadata_->use_inter_op_parallelism()) {
      inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
    }
  }
  inst_opts.is_multi_device_function = metadata_->use_multi_device_function();
  if (!params.function_handle_cache) {
    // If the caller does not provide a cache, we use the FLR cache.
    inst_opts.use_function_cache = true;
  }

  // We infer the target device from the function library runtime.
  DCHECK(lib->device() != nullptr);
  inst_opts.target = lib->device()->name();

  // Maps from a CompositeDevice name to underlying physical device names.
  absl::flat_hash_map<string, std::vector<string>> composite_devices;

  if (inst_opts.is_multi_device_function) {
    // Compute devices of non-captured inputs.
    //
    // We infer the number of non-captured inputs by subtracting the number
    // of captured inputs from the number of input arguments and we infer the
    // input devices from the function library runtime.
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(
        LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));
    size_t num_non_captured_inputs =
        fdef->signature().input_arg_size() - captured_inputs_.size();
    for (size_t i = 0; i < num_non_captured_inputs; ++i) {
      inst_opts.input_devices.push_back(inst_opts.target);
    }
    // Compute devices of captured inputs.
    // TODO(jsimsa): Correctly handle tensors on devices other than CPU:0.
    Device* cpu_device;
    TF_RETURN_IF_ERROR(lib->device_mgr()->LookupDevice("CPU:0", &cpu_device));
    std::unordered_map<int, DtypeAndPartialTensorShape>&
        input_resource_variable_dtypes_and_shapes =
            inst_opts.input_resource_dtypes_and_shapes;
    for (size_t i = 0; i < captured_inputs_.size(); ++i) {
      const auto& input = captured_inputs_[i];
      DataType dtype = input.dtype();
      if (dtype == DT_RESOURCE) {
        const auto& handles = input.flat<ResourceHandle>();
        const ResourceHandle& handle0 = handles(0);
        string composite_device;
        auto iter = fdef->arg_attr().find(num_non_captured_inputs + i);
        if (iter != fdef->arg_attr().end()) {
          auto arg_attr = iter->second.attr().find("_composite_device");
          if (arg_attr != iter->second.attr().end()) {
            composite_device = arg_attr->second.s();
          }
        }
        if (!composite_device.empty()) {
          if (composite_devices.find(composite_device) ==
              composite_devices.end()) {
            for (int i = 0; i < handles.size(); ++i) {
              composite_devices[composite_device].push_back(
                  handles(i).device());
            }
          }
          inst_opts.input_devices.push_back(composite_device);
        } else {
          inst_opts.input_devices.push_back(handle0.device());
        }
        const auto& dtypes_and_shapes = handle0.dtypes_and_shapes();
        // Set dtypes and shapes for resource variable inputs.
        if (!dtypes_and_shapes.empty()) {
          input_resource_variable_dtypes_and_shapes[num_non_captured_inputs +
                                                    i] =
              dtypes_and_shapes.at(0);
        }
      } else if (MTypeFromDType(dtype) == HOST_MEMORY) {
        inst_opts.input_devices.push_back(cpu_device->name());
      } else {
        // Fall back to using the function library runtime device.
        inst_opts.input_devices.push_back(inst_opts.target);
      }
    }

    for (const auto& it : composite_devices) {
      inst_opts.composite_devices[it.first] = &it.second;
    }

    for (int i = 0, end = fdef->signature().output_arg_size(); i < end; ++i) {
      inst_opts.output_devices.push_back(inst_opts.target);
    }

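    // On non-mobile platforms, supply a Grappler-based graph optimization
    // function for the multi-device function; layout optimization and
    // constant folding are explicitly disabled below.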
#if !defined(IS_MOBILE_PLATFORM)
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimizations are excluded because they assume that ops without
    // explicit device assignment will be placed on GPU (if available) but
    // that's not the case for operations within tf.data functions.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    // TODO(b/120437209): Re-enable constant folding.
    config_proto.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    inst_opts.optimize_graph_fn =
        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
                  std::placeholders::_2, std::placeholders::_3,
                  std::placeholders::_4, std::placeholders::_5,
                  std::move(config_proto), fdef->signature().name(),
                  std::move(optimization_options), std::placeholders::_6);
#endif  // !IS_MOBILE_PLATFORM
  }

  FunctionLibraryRuntime::Handle f_handle;
  if (params.function_handle_cache) {
    TF_RETURN_IF_ERROR(params.function_handle_cache->Instantiate(
        metadata_->func().name(), AttrSlice(&metadata_->func().attr()),
        inst_opts, &f_handle));
  } else {
    TF_RETURN_IF_ERROR(lib->Instantiate(metadata_->func().name(),
                                        AttrSlice(&metadata_->func().attr()),
                                        inst_opts, &f_handle));
  }

  DataTypeVector ret_types;
  TF_RETURN_IF_ERROR(lib->GetRetTypes(f_handle, &ret_types));

  bool is_multi_device;
  TF_RETURN_IF_ERROR(IsMultiDevice(lib, &is_multi_device));
  *instantiated_captured_function = absl::WrapUnique(
      new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
                                       *params.runner, this, is_multi_device));
  return OkStatus();
}

Status CapturedFunction::CheckExternalState() const {
  for (const auto& name : lib_def()->ListFunctionNames()) {
    TF_RETURN_IF_ERROR(
        IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
  }
  return OkStatus();
}

CapturedFunction::CapturedFunction(
    std::shared_ptr<const FunctionMetadata> metadata,
    std::vector<Tensor> captured_inputs)
    : metadata_(std::move(metadata)),
      captured_inputs_(std::move(captured_inputs)) {}

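// Sets `*is_multi_device` to true if the function cannot be executed as a
// single-device function on the device of `flr`: for example, if a captured
// resource input lives on an incompatible device, or if some op either has no
// kernel for the current device type or requests an incompatible device.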
Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr,
                                       bool* is_multi_device) const {
  if (!metadata_->use_multi_device_function()) {
    *is_multi_device = false;
    return OkStatus();
  }

  const FunctionDef* fdef;
  TF_RETURN_IF_ERROR(
      LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef));

  Device* current_device = flr->device();
  DeviceType current_device_type(current_device->device_type());
  DeviceNameUtils::ParsedName current_device_name;
  if (!DeviceNameUtils::ParseFullName(current_device->name(),
                                      &current_device_name)) {
    return errors::InvalidArgument("Failed to parse device name: ",
                                   current_device->name());
  }

  // Check if any of the captured inputs are placed on a device not compatible
  // with the current device. For non-captured inputs, we assume they are
  // placed on the current device.
  for (const auto& input : captured_inputs_) {
    DataType dtype = input.dtype();
    if (dtype == DT_RESOURCE) {
      const ResourceHandle& handle = input.flat<ResourceHandle>()(0);
      DeviceNameUtils::ParsedName resource_device_name;
      if (!DeviceNameUtils::ParseFullName(handle.device(),
                                          &resource_device_name)) {
        return errors::InvalidArgument("Failed to parse device name: ",
                                       handle.device());
      }
      if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                  resource_device_name)) {
        *is_multi_device = true;
        return OkStatus();
      }
    }
  }

  // Check if all ops could be placed on the current device.
  for (const auto& name : metadata_->lib_def()->ListFunctionNames()) {
    const FunctionDef* fdef;
    TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef));
    for (const auto& node : fdef->node_def()) {
      // Check if the op has a kernel available for the current device.
      if (!KernelDefAvailable(current_device_type, node)) {
        *is_multi_device = true;
        return OkStatus();
      }
      // If the op has a requested device, check if the requested device is
      // compatible with the current device.
      if (!node.device().empty()) {
        DeviceNameUtils::ParsedName node_device_name;
        if (!DeviceNameUtils::ParseFullName(node.device(),
                                            &node_device_name)) {
          return errors::InvalidArgument("Failed to parse device name: ",
                                         node.device());
        }
        if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name,
                                                    node_device_name)) {
          *is_multi_device = true;
          return OkStatus();
        }
      }
    }
  }

  *is_multi_device = false;
  return OkStatus();
}

InstantiatedCapturedFunction::InstantiatedCapturedFunction(
    FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
    DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
    CapturedFunction* captured_func, bool is_multi_device)
    : lib_(lib),
      f_handle_(f_handle),
      ret_types_(std::move(ret_types)),
      captured_runner_(std::move(runner)),
      captured_func_(captured_func),
      is_multi_device_(is_multi_device) {}

Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
                                         std::vector<Tensor>&& args,
                                         std::vector<Tensor>* rets) const {
  return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::Run(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, std::move(args), captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage = node && ctx->model();
  f_opts.stats_collector = stats_collector.get();

  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
                           ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::Run",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    // TODO(jsimsa): Factor out common code for Run, RunAsync, and
    // RunWithBorrowedArguments
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* ret) const {
  return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
    IteratorContext* ctx, const std::vector<Tensor>& args,
    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager(ctx->cancellation_manager());
  f_opts.cancellation_manager = &cancellation_manager;
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage = node && ctx->model();
  f_opts.stats_collector = stats_collector.get();

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunWithBorrowedArgs",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  if (node) {
    // Resource usage for function execution is gathered from the executor.
    if (collect_usage) node->record_stop(EnvTime::NowNanos());
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
    if (ctx->stats_aggregator()) {
      string prefix_with_func_name = strings::StrCat(
          node->name(), stats_utils::kDelimiter, captured_func_->func().name());
      ctx->stats_aggregator()->AddToHistogram(
          stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
          {static_cast<float>(stats_collector->processing_time())},
          node->num_elements());
    }
    node->add_processing_time(stats_collector->processing_time());
    if (collect_usage) node->record_start(EnvTime::NowNanos());
  } else {
    TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  }
  return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunInstantiated(
    const std::vector<Tensor>& args, std::vector<Tensor>* rets) {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    return RunShortCircuit(info, args, captured_func_, rets);
  }

  FunctionLibraryRuntime::Options f_opts;
  ScopedStepContainer step_container(
      f_opts.step_id, [this](const string& name) {
        lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = &step_container;
  f_opts.runner = &captured_runner_;
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  CancellationManager cancellation_manager;
  f_opts.cancellation_manager = &cancellation_manager;

  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
                              ret_types_);
  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode(
            "InstantiatedCapturedFunction::RunInstantiated",
            {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
  return frame.ConsumeRetvals(rets);
}

void InstantiatedCapturedFunction::RunAsync(
    IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done,
    const std::shared_ptr<model::Node>& node) const {
  auto& info = captured_func_->short_circuit_info();
  if (!info.indices.empty()) {
    // Run the `done` callback on a threadpool thread, because it will
    // potentially do a non-trivial amount of (e.g. copying) work, and we may
    // want to run that concurrently with the next invocation.
    Status s = RunShortCircuit(info, std::move(args), captured_func_, rets);
    (*ctx->runner())(
        std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); },
                  std::move(done)));
    return;
  }

  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
  // be deleted before `done` is called. Take care not to capture `ctx` in any
  // code that may execute asynchronously in this function.
  OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
      std::move(args), &captured_func_->captured_inputs(), ret_types_);

  FunctionLibraryRuntime::Options f_opts;
  ResourceMgr* resource_mgr = lib_->device()->resource_manager();
  ScopedStepContainer* step_container = new ScopedStepContainer(
      f_opts.step_id, [resource_mgr](const string& name) {
        resource_mgr->Cleanup(name).IgnoreError();
      });
  f_opts.step_container = step_container;
  f_opts.runner = ctx->runner();
  f_opts.create_rendezvous = ShouldCreateRendezvous();
  auto cancellation_manager =
      std::make_unique<CancellationManager>(ctx->cancellation_manager());
  f_opts.cancellation_manager = cancellation_manager.get();
  f_opts.collective_executor = ctx->collective_executor();

  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
  if (node || ctx->stats_aggregator()) {
    stats_collector = std::make_shared<SimpleStepStatsCollector>();
  }
  const bool collect_usage = node && ctx->model();
  f_opts.stats_collector = stats_collector.get();

  // Transfer ownership of the cancellation manager to `callback`.
  CancellationManager* raw_cancellation_manager =
      cancellation_manager.release();
  auto callback = std::bind(
      [this, rets, step_container, raw_cancellation_manager, frame, node,
       collect_usage](
          const FunctionLibraryRuntime::DoneCallback& done,
          IteratorContext* ctx,
          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
          // Begin unbound arguments.
          Status s) {
        delete step_container;
        delete raw_cancellation_manager;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        if (node) {
          // TODO(b/129085499) Utilize the `node_name` which would be unique
          // than the prefix for the function execution time statistics.
          // prefix_with_func_name would then be node_name + func_name.
          if (ctx->stats_aggregator()) {
            string prefix_with_func_name =
                strings::StrCat(node->name(), stats_utils::kDelimiter,
                                captured_func_->func().name());
            ctx->stats_aggregator()->AddToHistogram(
                stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                {static_cast<float>(stats_collector->processing_time())},
                node->num_elements());
          }
          node->add_processing_time(stats_collector->processing_time());
        }
        if (collect_usage) {
          node->record_start(EnvTime::NowNanos());
        }
        done(s);
        if (collect_usage) {
          node->record_stop(EnvTime::NowNanos());
        }
      },
      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);

  profiler::TraceMe activity(
      [&] {
        return profiler::TraceMeEncode("InstantiatedCapturedFunction::RunAsync",
                                       {{"id", f_opts.step_id}});
      },
      profiler::TraceMeLevel::kInfo);
  // Stop the usage collection before calling `Run()` because `callback` may
  // be executed synchronously, and so the `node->record_start()` call within
  // `callback` would violate nesting.
  if (collect_usage) node->record_stop(EnvTime::NowNanos());
  lib_->Run(f_opts, f_handle_, frame, std::move(callback));
  if (collect_usage) node->record_start(EnvTime::NowNanos());
}

bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
  // Rendezvous should only be created by the FLR for non-CPU single-device
  // functions. For multi-device functions the appropriate rendezvous will be
  // created by the process FLR.
  return lib_->device()->device_type() != DEVICE_CPU && !is_multi_device_;
}

}  // namespace data
}  // namespace tensorflow