1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/common_runtime/function.h"

#include <deque>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/single_threaded_executor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
55
56// See core/kernels/function_ops.cc for related kernels.
57
58namespace tensorflow {
59
// A few string constants used throughout this module. They alias the
// canonical op names declared on FunctionLibraryDefinition so that the
// graph transformations below stay in sync with the framework.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix for nodes added by the AddArg/AddRet helpers below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
72
73// Represents the index-th output of a node.
74struct Endpoint {
75 Node* node;
76 int index;
77
78 // Returns the string name represents this endpoint.
79 string name() const {
80 if (index == 0) {
81 return node->name();
82 } else {
83 return strings::StrCat(node->name(), ":", index);
84 }
85 }
86
87 DataType dtype() const { return node->output_type(index); }
88};
89
90struct EndpointHash {
91 uint64 operator()(const Endpoint& x) const {
92 return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
93 x.index);
94 }
95};
96
97struct EndpointEq {
98 bool operator()(const Endpoint& x, const Endpoint& y) const {
99 return (x.node == y.node) && (x.index == y.index);
100 }
101};
102
103// The following Add* routines are used to add a few graph nodes while
104// functions are transformed.
105static Node* AddArg(Graph* g, DataType dtype, int index) {
106 DCHECK_LT(0, dtype);
107 DCHECK_LT(dtype, DT_FLOAT_REF);
108 NodeDef ndef;
109 ndef.set_name(g->NewName(kNodeLabel));
110 ndef.set_op(kArgOp);
111 AddNodeAttr("T", dtype, &ndef);
112 AddNodeAttr("index", index, &ndef);
113 Status s;
114 Node* ret = g->AddNode(ndef, &s);
115 TF_CHECK_OK(s);
116 return ret;
117}
118
119static Node* AddRet(Graph* g, Endpoint input, int index) {
120 DCHECK_LT(0, input.dtype());
121 DCHECK_LT(input.dtype(), DT_FLOAT_REF);
122 NodeDef ndef;
123 ndef.set_name(g->NewName(kNodeLabel));
124 ndef.set_op(kRetOp);
125 ndef.add_input(input.name());
126 AddNodeAttr("T", input.dtype(), &ndef);
127 AddNodeAttr("index", index, &ndef);
128 Status s;
129 Node* ret = g->AddNode(ndef, &s);
130 TF_CHECK_OK(s);
131 g->AddEdge(input.node, input.index, ret, 0);
132 return ret;
133}
134
135// FunctionLibraryRuntime implementation that forwards all the function calls to
136// the base runtime implementation, and only overrides FunctionLibraryDefinition
137// in calls to Instantiate (if caller doesn't provide the
138// InstantiateOptions::lib_def option).
139//
140// When the function library runtime (FunctionLibraryRuntimeImpl specifically)
141// instantiates a function into a Graph object, it also creates an Executor for
142// it. That executor has a pointer to the function library runtime instance,
143// that is used to instantiate all nested function calls.
144//
145// The function library definition used to instantiate the function must be
146// preserved in the Executor's function library runtime.
147//
148// IMPORTANT: This runtime is intended for use only in executors created for
149// functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  // `base_flr` and `lib_def` are borrowed; both must outlive this object.
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to the base runtime, injecting `lib_def_` as
  // InstantiateOptions::lib_def when the caller did not set one.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always fails: kernel creation is intentionally unsupported through the
  // overlay (see the definition for rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Answered from `lib_def_` only; NOT forwarded to the base runtime.
  bool IsStateful(const string& function_name) const override;

  // Returns `lib_def_` when set; otherwise falls back to the base runtime.
  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  // The accessors below delegate directly to the base runtime.
  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
206
207FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
208
209Status FunctionLibraryRuntimeOverlay::Instantiate(
210 const string& function_name, AttrSlice attrs,
211 const InstantiateOptions& options, Handle* handle) {
212 // We automatically set the `lib_def` option for all instantiations, if the
213 // caller doesn't set this option explicitly.
214 if (!options.lib_def && lib_def_) {
215 InstantiateOptions options_copy = options;
216 options_copy.lib_def = lib_def_;
217 return base_flr_->Instantiate(function_name, attrs, options_copy, handle);
218 } else {
219 return base_flr_->Instantiate(function_name, attrs, options, handle);
220 }
221}
222
// The methods below simply delegate to the wrapped base runtime. The
// overlay adds no behavior beyond the lib_def override in Instantiate().

Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}

const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}

Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}

void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}

Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
259
// Kernel creation through the overlay always fails by design.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime (aka
  // FunctionLibraryRuntimeImpl), so to make sure we do not create a kernel with
  // the wrong lib_def we just disable creation of new kernels through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
272
273bool FunctionLibraryRuntimeOverlay::IsStateful(
274 const string& function_name) const {
275 // Important: we do not forward lookup to the base FLR.
276 const OpDef* op_def;
277 const Status s = lib_def_->LookUpOpDef(function_name, &op_def);
278 return s.ok() && op_def->is_stateful();
279}
280
// Simple accessors, all forwarded to the base runtime (except
// GetFunctionLibraryDefinition, which prefers the overlay's lib_def_).

Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }

const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}

Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }

const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}

std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}

const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}

const FunctionLibraryDefinition*
FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const {
  return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition();
}
306
string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}

int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}

// Cloning drops the FunctionLibraryDefinition override; see NOTE below.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
324
// Device-local FunctionLibraryRuntime: instantiates functions into Graph
// objects, creates executors for them, and runs them on `device_`.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // Immutable configuration captured at construction. None of these
  // pointers are owned by this class.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Cached callbacks bound to this runtime, handed to helpers that need
  // signature lookup / kernel creation.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  // Guards next_handle_ and items_.
  mutable mutex mu_;

  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    // Number of outstanding Instantiate() calls for this function;
    // the Item is destroyed when it drops back to zero.
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;
    bool allow_small_function_optimizations = false;
    bool allow_control_flow_sync_execution = false;

    // func_graph, exec and overlay_flr are owned by the Item.
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Wrapped in unique_ptr so the destructor can release all items before
  // member destruction; see ~FunctionLibraryRuntimeImpl.
  std::unique_ptr<absl::flat_hash_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  // Returns true if `options` selects this runtime's own device (or no
  // device at all), so instantiation should happen locally.
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  // TODO(fishx): Avoid using std::unique_ptr for PrivateIntraProcessRendezvous,
  // since it will allocate the object on heap.
  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  // Translates FunctionLibraryRuntime::Options into Executor::Args.
  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
465
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(std::make_unique<
             absl::flat_hash_map<Handle, std::unique_ptr<Item>>>()),
      function_handle_cache_(std::make_unique<FunctionHandleCache>(this)),
      parent_(parent) {
  // Bind the signature-lookup and kernel-creation callbacks to this
  // instance; they are handed to helpers that cannot see `this`.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Pick the runner's thread pool: prefer the device's own pool, then the
  // provided default; leave default_runner_ null when neither exists.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
510
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered
  // in this object. A function may contain a few sub-functions which have
  // also been registered in this object. Deleting the parent function will
  // call ReleaseHandle in this class again for each of the sub-functions.
  // These circular calls may cause segfault since the items_ may have already
  // been partially deleted when releasing handles of sub-functions.
  // Explicitly release items_ here and check it in ReleaseHandle to avoid
  // this.
  items_.reset();
}
521
522// An asynchronous op kernel which executes an instantiated function
523// defined in a library.
524class CallOp : public AsyncOpKernel {
525 public:
526 CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
527 : AsyncOpKernel(ctx), handle_(handle) {}
528
529 ~CallOp() override {
530 // TODO(iga): Release the cached handle_
531 }
532
533 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
534 FunctionLibraryRuntime* lib = ctx->function_library();
535 OP_REQUIRES_ASYNC(ctx, lib != nullptr,
536 errors::Internal("No function library is provided."),
537 done);
538 FunctionLibraryRuntime::Options opts;
539 opts.rendezvous = ctx->rendezvous();
540 opts.cancellation_manager = ctx->cancellation_manager();
541 opts.step_container = ctx->step_container();
542 opts.stats_collector = ctx->stats_collector();
543 opts.runner = ctx->runner();
544 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
545 opts.collective_executor = ctx->collective_executor();
546 opts.stack_trace = ctx->stack_trace();
547 std::vector<Tensor> args;
548 args.reserve(ctx->num_inputs());
549 for (int i = 0; i < ctx->num_inputs(); ++i) {
550 args.push_back(ctx->input(i));
551 }
552 std::vector<Tensor>* rets = new std::vector<Tensor>;
553 lib->Run(opts, handle_, args, rets,
554 [ctx, done, rets](const Status& status) {
555 if (!status.ok()) {
556 ctx->SetStatus(status);
557 } else {
558 const int ret_size = static_cast<int>(rets->size());
559 CHECK_EQ(ret_size, ctx->num_outputs());
560 for (int i = 0; i < ret_size; ++i) {
561 ctx->set_output(i, (*rets)[i]);
562 }
563 }
564 delete rets;
565 done();
566 });
567 }
568
569 private:
570 FunctionLibraryRuntime::Handle handle_;
571
572 TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
573};
574
575const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
576 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
577 if (local_handle == kInvalidLocalHandle) {
578 LOG(ERROR) << "Could not find Handle: " << h
579 << " on device: " << device_name_;
580 return nullptr;
581 }
582
583 tf_shared_lock l(mu_);
584 auto iter = items_->find(local_handle);
585 CHECK(iter != items_->end());
586 return iter->second->func_graph;
587}
588
589Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
590 DataTypeVector* ret_types) {
591 if (parent_->IsMultiDevice(h)) {
592 return parent_->GetRetTypes(h, ret_types);
593 }
594 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
595 if (local_handle == kInvalidLocalHandle) {
596 return errors::InvalidArgument("Handle ", h, " not found.");
597 }
598 const FunctionBody* fbody = GetFunctionBody(h);
599 *ret_types = fbody->ret_types;
600 return OkStatus();
601}
602
// Public entry point: creates the kernel against this FLR itself.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
607
// Creates an OpKernel for `props`. Primitive ops get their registered
// kernel; function calls are instantiated (against `flr`'s library) and
// wrapped in a CallOp. On success `*kernel` is owned by the caller.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    FunctionLibraryRuntime* flr, OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  const CustomKernelCreator* custom_kernel_creator =
      GetDefaultCustomKernelCreator();
  if (custom_kernel_creator &&
      custom_kernel_creator->CanCreateKernel(*flr, props)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator->CreateKernel(flr, props, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(props->node_def.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
                                 AttrSlice(&props->node_def.attr()), options,
                                 &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  auto new_props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
      fbody->ret_types);
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
      device_->resource_manager(), props, input_memory_types,
      output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
676
677Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
678 const FunctionDef& fdef, AttrSlice attrs,
679 const FunctionLibraryDefinition* lib_def,
680 std::unique_ptr<FunctionBody>* fbody) {
681 if (lib_def == base_lib_def_) {
682 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
683 } else {
684 auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
685 return lib_def->LookUpOpDef(op, sig);
686 };
687 return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
688 }
689}
690
// Builds the body of the symbolic gradient of `func`: for a primitive op,
// via its registered gradient creator; for a user-defined function, by
// instantiating it and differentiating the resulting body.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return OkStatus();
}
724
725bool FunctionLibraryRuntimeImpl::IsLocalTarget(
726 const InstantiateOptions& options) const {
727 if (device_ == nullptr) return true;
728 if (options.target.empty()) return true;
729 if (options.is_multi_device_function) return false;
730 Device* target_device;
731 if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
732 VLOG(1) << "Not instantiating function in FLR because failed to "
733 << "find device " << options.target << " in device manager";
734 return false;
735 }
736 if (target_device != device_) {
737 VLOG(1) << "Not instantiating function in FLR because target device "
738 << options.target
739 << " is different from FLR's device: " << device_->DebugString();
740 return false;
741 }
742 return true;
743}
744
// Instantiates `function_name` on this device, reusing an existing Item
// (bumping its refcount) when the canonicalized key is already known.
// Non-local targets are delegated to the parent runtime.
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Route through the handle cache once, clearing the flag to avoid
  // re-entering this branch on the recursive call.
  if (options.use_function_cache) {
    InstantiateOptions options_copy(options);
    options_copy.use_function_cache = false;
    return function_handle_cache_->Instantiate(function_name, attrs,
                                               options_copy, handle);
  }

  // Since this is a local target, ensure that the local `device_name_`
  // appears in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  // Fast path: the function was already instantiated under this key.
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle ", *handle,
                                ".");
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle ", handle_on_device,
                                " for handle ", *handle,
                                " not found in items.");
      }
      ++item_handle->second->instantiation_counter;
      return OkStatus();
    }
  }

  // Slow path: build the function body (outside the lock, since this can
  // recurse back into Instantiate for gradients / nested functions).
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // Prefer a registered custom gradient over the symbolic one.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  // Re-check under the lock: another thread may have instantiated the same
  // key while the body was being built.
  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      item->allow_small_function_optimizations =
          options.allow_small_function_optimizations;
      item->allow_control_flow_sync_execution =
          options.allow_control_flow_sync_execution;
      if (options.lib_def) {
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return OkStatus();
}
844
// Decrements the instantiation count for `handle`; destroys the Item (and
// unregisters the handle from the parent) when the count reaches zero.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    return parent_->ReleaseHandle(handle);
  }
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items have already been released.
    if (items_ == nullptr) return OkStatus();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle ",
          h, " but found none");
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  return parent_status;
}
879
880namespace {
881
882// Removes all stateless nodes that do not contribute to a return
883// value from the function body. Unlike `RemoveDeadNodes()`, which is
884// triggered by `OptimizerOptions.do_function_inlining`, this pass
885// ignores the SINK node, from which (by definition) all nodes are
886// reverse reachable, and preserves all nodes that are reachable from
887// control output nodes.
888//
889// TODO(ezhulenev, skyewm): Function body should not have special treatment of
890// stateful ops, graph should encode nodes that must execute with `control_ret`
891// and `control_output`.
892void PruneFunctionBody(const FunctionDef& fdef, Graph* g) {
893 VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name();
894
895 // `control_ret` nodes must be always executed.
896 std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes;
897 for (const auto& control_ret : fdef.control_ret()) {
898 control_ret_nodes.insert(control_ret.second);
899 }
900
901 std::unordered_set<const Node*> nodes;
902 for (auto n : g->nodes()) {
903 // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
904 // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
905 // specifically exclude them as seeds, to avoid unconditionally executing
906 // unused argument nodes (e.g. in a function like `lambda x, y: y`).
907 // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
908 // still needed. It would be preferable to prune entire loops and/or
909 // conditionals if they are not used in the graph.
910 if (n->IsControlFlow() ||
911 (n->op_def().is_stateful() && n->type_string() != kArgOp) ||
912 (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) {
913 nodes.insert(n);
914 }
915 }
916 bool changed = PruneForReverseReachability(g, std::move(nodes));
917 if (changed) {
918 FixupSourceAndSinkEdges(g);
919 }
920}
921
922} // namespace
923
// Builds the executor for `*item`'s function graph: copies the graph, prunes
// and optimizes it, verifies per-device memory types, and creates an
// executor. On success, ownership of the optimized graph and the executor is
// transferred to the item (unless a concurrent caller installed one first).
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    // Snapshot the fields we need under the lock; the heavy graph work
    // below must not hold mu_.
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  // Work on a copy so the stored function body graph stays unmodified.
  auto g = std::make_unique<Graph>(lib_def);
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, GraphOptimizer::Options());
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  params.allow_control_flow_sync_execution =
      (*item)->allow_control_flow_sync_execution;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // With an overlay library, kernels must be created against that overlay
    // so nested function lookups resolve against the overlay's definitions.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;

  // When the instantiation options request small function optimizations, all
  // graphs which are safe for synchronous execution will set this flag to true:
  if ((*item)->allow_small_function_optimizations && executor_type.empty()) {
    executor_type = "SINGLE_THREADED_EXECUTOR";
  }

  metrics::IncrementTestCounter("flr_executor",
                                (executor_type == "SINGLE_THREADED_EXECUTOR")
                                    ? "single_threaded"
                                    : "default");

  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    // Only install our graph/executor if nobody beat us to it.
    if ((*item)->exec == nullptr) {
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return OkStatus();
}
990
991Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle,
992 Item** item) {
993 {
994 tf_shared_lock l(mu_);
995 auto iter = items_->find(local_handle);
996 if (iter == items_->end()) {
997 return errors::Internal("Local function handle ", local_handle,
998 " is not valid. Likely an internal error.");
999 }
1000 *item = iter->second.get();
1001 if ((*item)->exec != nullptr) {
1002 return OkStatus();
1003 }
1004 }
1005 // NOTE: We need to call CreateItem out of mu_ because creating an
1006 // executor needs to call CreateKernel.
1007 return CreateItem(item);
1008}
1009
1010void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions(
1011 const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame,
1012 Executor::Args* exec_args) {
1013 // Inherit the step_id from the caller.
1014 exec_args->step_id = run_opts.step_id;
1015 exec_args->rendezvous = run_opts.rendezvous;
1016 exec_args->stats_collector = run_opts.stats_collector;
1017 exec_args->cancellation_manager = run_opts.cancellation_manager;
1018 exec_args->step_container = run_opts.step_container;
1019 if (run_opts.runner) {
1020 exec_args->runner = *run_opts.runner;
1021 } else {
1022 exec_args->runner = default_runner_;
1023 }
1024 exec_args->collective_executor = run_opts.collective_executor;
1025 exec_args->call_frame = frame;
1026 exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline;
1027 exec_args->user_intra_op_threadpool = run_opts.user_intra_op_threadpool;
1028 exec_args->coordination_service_agent = run_opts.coordination_service_agent;
1029 exec_args->stack_trace = run_opts.stack_trace;
1030}
1031
// Runs the function identified by `handle` on this device on behalf of a
// caller on `opts.source_device`: receives the arguments via rendezvous,
// runs the local executor, and sends the return values back. `done` is
// invoked exactly once with the final status.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Device incarnations are needed to key the rendezvous transfers.
  int64_t src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated because they must outlive this call; ownership is
  // released manually inside the callbacks below.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  // Per-argument and per-return allocator attributes for the transfers.
  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Receiving or staging the arguments failed: free everything
          // allocated above before reporting the error.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller's device.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // NOTE(review): assumes the executor does not retain a reference to
        // *exec_args once RunAsync has been invoked.
        delete exec_args;
      });
}
1126
// Asynchronously runs the function identified by `handle` with `args`,
// storing outputs in `rets` and invoking `done` with the final status.
// Delegates non-local handles to the parent runtime and remote execution to
// RunRemote().
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started"));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // This call owns one ref on the rendezvous; it is released by the
    // wrapped `done` below.
    auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      rendezvous->Unref();
      done(status);
    };
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    // Not instantiated on this device: delegate to the parent runtime.
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated so it outlives this call; deleted in the done callback.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun",
                                       {{"id", opts.step_id}, {"_r", 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Transfer the function's return values out of the frame.
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1206
1207void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
1208 CallFrameInterface* frame,
1209 DoneCallback done) {
1210 if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
1211 done(errors::Cancelled(""));
1212 return;
1213 }
1214
1215 Options run_opts = opts;
1216 if (opts.create_rendezvous) {
1217 auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
1218 run_opts.rendezvous = rendezvous;
1219 run_opts.create_rendezvous = false;
1220 done = [done = std::move(done), rendezvous](const Status& status) mutable {
1221 rendezvous->Unref();
1222 done(status);
1223 };
1224 }
1225
1226 LocalHandle local_handle = parent_->GetHandleOnDevice(
1227 device_name_, handle, /*include_multi_device=*/true);
1228 if (local_handle == kInvalidLocalHandle) {
1229 parent_->Run(run_opts, handle, frame, done);
1230 return;
1231 }
1232
1233 if (opts.remote_execution) {
1234 // NOTE(mrry): This bit is only set for a local function when `parent_`
1235 // calls back into this class, and the current implementation of
1236 // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
1237 // `args`/`rets` interface.
1238 done(errors::Unimplemented("Remote calling with CallFrameInterface"));
1239 return;
1240 }
1241
1242 Item* item = nullptr;
1243 Status s = GetOrCreateItem(local_handle, &item);
1244 if (!s.ok()) {
1245 done(s);
1246 return;
1247 }
1248 if (run_opts.runner == nullptr) {
1249 run_opts.runner = &default_runner_;
1250 }
1251 DCHECK(run_opts.runner != nullptr);
1252
1253 profiler::TraceMeProducer activity(
1254 // To TraceMeConsumers in ExecutorState::Process/Finish.
1255 [&opts] {
1256 return profiler::TraceMeEncode("FunctionRun",
1257 {{"id", opts.step_id}, {"_r", 1}});
1258 },
1259 profiler::ContextType::kTfExecutor, opts.step_id,
1260 profiler::TraceMeLevel::kInfo);
1261
1262 Executor::Args exec_args;
1263 ExecutorArgsFromOptions(run_opts, frame, &exec_args);
1264 item->exec->RunAsync(exec_args, std::move(done));
1265}
1266
1267Status FunctionLibraryRuntimeImpl::PrepareRunSync(
1268 Handle handle, Options* run_opts, Item** out_item,
1269 std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
1270 if (run_opts->cancellation_manager &&
1271 run_opts->cancellation_manager->IsCancelled()) {
1272 return errors::Cancelled("");
1273 }
1274
1275 if (run_opts->remote_execution) {
1276 // NOTE(mrry): This bit is only set for a local function when `parent_`
1277 // calls back into this class, and the current implementation of
1278 // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
1279 // Run() method.
1280 return errors::Unimplemented("Remote calling with RunSync()");
1281 }
1282
1283 if (run_opts->create_rendezvous) {
1284 *out_rendezvous =
1285 std::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
1286 run_opts->rendezvous = out_rendezvous->get();
1287 run_opts->create_rendezvous = false;
1288 }
1289
1290 LocalHandle local_handle = parent_->GetHandleOnDevice(
1291 device_name_, handle, /*include_multi_device=*/true);
1292 if (local_handle == kInvalidLocalHandle) {
1293 *out_item = nullptr;
1294 return OkStatus();
1295 }
1296
1297 TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));
1298
1299 if (run_opts->runner == nullptr) {
1300 run_opts->runner = &default_runner_;
1301 }
1302 DCHECK(run_opts->runner != nullptr);
1303
1304 return OkStatus();
1305}
1306
1307Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1308 gtl::ArraySlice<Tensor> args,
1309 std::vector<Tensor>* rets) {
1310 Item* item = nullptr;
1311 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1312 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1313 if (item == nullptr) {
1314 return parent_->RunSync(opts, handle, args, rets);
1315 }
1316
1317 Executor::Args exec_args;
1318 const FunctionBody* fbody = GetFunctionBody(handle);
1319 FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
1320 TF_RETURN_IF_ERROR(frame.SetArgs(args));
1321 ExecutorArgsFromOptions(opts, &frame, &exec_args);
1322
1323 TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
1324 return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
1325}
1326
1327Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
1328 CallFrameInterface* call_frame) {
1329 Item* item = nullptr;
1330 std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
1331 TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
1332 if (item == nullptr) {
1333 return parent_->RunSync(opts, handle, call_frame);
1334 }
1335
1336 Executor::Args exec_args;
1337 ExecutorArgsFromOptions(opts, call_frame, &exec_args);
1338 return item->exec->Run(exec_args);
1339}
1340
1341bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
1342 const OpDef* op_def;
1343 const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
1344 return s.ok() && op_def->is_stateful();
1345}
1346
1347string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
1348 Item* item = nullptr;
1349 LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1350 Status s = GetOrCreateItem(local_handle, &item);
1351 if (s.ok()) {
1352 if (item->graph) {
1353 return tensorflow::DebugString(item->graph.get());
1354 } else {
1355 return tensorflow::DebugString(item->func_graph->graph);
1356 }
1357 } else {
1358 return s.ToString();
1359 }
1360}
1361
1362Status FunctionLibraryRuntimeImpl::Clone(
1363 std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
1364 std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
1365 FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
1366 TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_,
1367 optimizer_.options(), out_lib_def, out_pflr,
1368 skip_flib_def));
1369 *out_flr = (*out_pflr)->GetFLR(device_->name());
1370 if (*out_flr != nullptr) {
1371 return OkStatus();
1372 } else {
1373 return errors::Internal("Cloning FunctionLibraryRuntime failed.");
1374 }
1375}
1376
1377namespace {
1378
1379struct CustomCreatorSingleton {
1380 mutex mu;
1381 std::unique_ptr<CustomKernelCreator> custom_creator = nullptr;
1382
1383 void Set(CustomKernelCreator* cb) {
1384 mutex_lock l(mu);
1385 custom_creator.reset(cb);
1386 }
1387
1388 CustomKernelCreator* Get() {
1389 mutex_lock l(mu);
1390 return custom_creator.get();
1391 }
1392};
1393
1394CustomCreatorSingleton* GetCustomCreatorSingleton() {
1395 static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
1396 return ccs;
1397}
1398
1399} // namespace
1400
1401const CustomKernelCreator* GetDefaultCustomKernelCreator() {
1402 return GetCustomCreatorSingleton()->Get();
1403}
1404
1405void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
1406 GetCustomCreatorSingleton()->Set(c);
1407}
1408
1409std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
1410 const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
1411 Device* device, int graph_def_version,
1412 const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
1413 const OptimizerOptions& optimizer_options,
1414 const SessionMetadata* session_metadata,
1415 ProcessFunctionLibraryRuntime* parent) {
1416 return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
1417 device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
1418 optimizer_options, session_metadata, parent));
1419}
1420
// Helper that builds the symbolic gradient of a function body `f`: a new
// FunctionBody taking f's inputs plus one incoming-gradient input per output
// of f, and returning one gradient output per input of f.
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Builds and returns the gradient function body; ownership passes to the
  // caller.
  std::unique_ptr<FunctionBody> Compute();

 private:
  // Not owned; must outlive this helper.
  const FunctionBody* fbody_;

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1436
1437void SymbolicGradientHelper::Copy(FunctionBody* gbody) {
1438 const Graph& src = *(fbody_->graph);
1439 gbody->graph = new Graph(src.op_registry());
1440 Graph* dst = gbody->graph;
1441
1442 std::vector<Node*> node_map(src.num_node_ids());
1443
1444 // Copy just the fdef attributes (copy '_noinline' and other similar flags to
1445 // the gradient function body).
1446 *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr();
1447
1448 // Copy the nodes.
1449 node_map[src.source_node()->id()] = dst->source_node();
1450 node_map[src.sink_node()->id()] = dst->sink_node();
1451 for (Node* n : src.op_nodes()) {
1452 node_map[n->id()] = dst->CopyNode(n);
1453 }
1454
1455 // Copy the edges.
1456 for (const Edge* e : src.edges()) {
1457 Node* src_copy = node_map[e->src()->id()];
1458 Node* dst_copy = node_map[e->dst()->id()];
1459 dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1460 }
1461
1462 // Save inputs in copied graph.
1463 CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1464 gbody->arg_types = fbody_->arg_types;
1465 for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1466 gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1467 }
1468
1469 // Save outputs in copied graph.
1470 CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1471 gbody->ret_types = fbody_->ret_types;
1472 for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1473 gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1474 }
1475}
1476
1477std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
1478 FunctionBody* gbody = new FunctionBody;
1479 Copy(gbody); // copy fbody_ into gbody.
1480
1481 Graph* g = gbody->graph;
1482
1483 const int num_y = static_cast<int>(gbody->ret_nodes.size());
1484
1485 // Populate 'y_node_outputs_' with node function body outputs.
1486 // Populate 'y_grad_nodes' with initial gradient nodes for each return node
1487 // of the original function body (these will be 'arg' nodes in the function
1488 // gradient body).
1489 std::vector<NodeOut> y_node_outputs;
1490 y_node_outputs.reserve(num_y);
1491 std::vector<NodeOut> y_grad_node_outputs;
1492 y_grad_node_outputs.reserve(num_y);
1493 for (int i = 0; i < num_y; ++i) {
1494 Node* y = gbody->ret_nodes[i];
1495 y_node_outputs.push_back({y, 0});
1496 DCHECK_EQ(y->type_string(), kRetOp);
1497 const DataType dtype = y->input_type(0);
1498 const int index = static_cast<int>(gbody->arg_nodes.size());
1499 Node* dy = AddArg(g, dtype, index);
1500 gbody->arg_types.push_back(dtype);
1501 gbody->arg_nodes.push_back(dy);
1502 y_grad_node_outputs.push_back({dy, 0});
1503 }
1504
1505 // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
1506 const size_t num_x = fbody_->arg_nodes.size();
1507 std::vector<NodeOut> x_node_outputs;
1508 x_node_outputs.reserve(num_x);
1509 for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1510 x_node_outputs.push_back({gbody->arg_nodes[i], 0});
1511 }
1512
1513 // Call AddSymbolicGradients which will add nodes to graph 'g' that
1514 // compute the function gradient (adding an entry in 'x_grad_node_outputs'
1515 // for each node in 'x_node_outputs').
1516 std::vector<NodeOut> x_grad_node_outputs;
1517 TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
1518 y_grad_node_outputs, &x_grad_node_outputs,
1519 g));
1520
1521 // Remove the old return nodes from the function body.
1522 for (Node* n : gbody->ret_nodes) {
1523 g->RemoveNode(n);
1524 }
1525 gbody->ret_types = fbody_->arg_types;
1526 // TODO(apassos): use the right dtype for gradients of resource variables
1527 for (int i = 0; i < gbody->ret_types.size(); ++i) {
1528 if (gbody->ret_types[i] == DT_RESOURCE) {
1529 gbody->ret_types[i] = DT_FLOAT;
1530 }
1531 }
1532 gbody->ret_nodes.clear();
1533 // Add new return nodes to the function gradient body for each node
1534 // in 'x_grad_nodes'.
1535 const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
1536 for (int i = 0; i < arg_types_size; ++i) {
1537 Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
1538 Node* ret = AddRet(g, grad, i);
1539 gbody->ret_nodes.push_back(ret);
1540 }
1541
1542 return std::unique_ptr<FunctionBody>(gbody);
1543}
1544
1545std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) {
1546 return SymbolicGradientHelper(f).Compute();
1547}
1548
1549} // end namespace tensorflow
1550