1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/function.h" |
17 | |
18 | #include <deque> |
19 | #include <vector> |
20 | |
21 | #include "absl/algorithm/container.h" |
22 | #include "absl/memory/memory.h" |
23 | #include "absl/strings/str_cat.h" |
24 | #include "tensorflow/core/common_runtime/device.h" |
25 | #include "tensorflow/core/common_runtime/executor.h" |
26 | #include "tensorflow/core/common_runtime/executor_factory.h" |
27 | #include "tensorflow/core/common_runtime/gradients.h" |
28 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
29 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
30 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
31 | #include "tensorflow/core/common_runtime/memory_types.h" |
32 | #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
33 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
34 | #include "tensorflow/core/common_runtime/single_threaded_executor.h" |
35 | #include "tensorflow/core/framework/collective.h" |
36 | #include "tensorflow/core/framework/function.h" |
37 | #include "tensorflow/core/framework/function_handle_cache.h" |
38 | #include "tensorflow/core/framework/metrics.h" |
39 | #include "tensorflow/core/framework/node_def.pb.h" |
40 | #include "tensorflow/core/framework/node_def_util.h" |
41 | #include "tensorflow/core/framework/op.h" |
42 | #include "tensorflow/core/framework/op_kernel.h" |
43 | #include "tensorflow/core/framework/versions.pb.h" |
44 | #include "tensorflow/core/graph/algorithm.h" |
45 | #include "tensorflow/core/graph/control_flow.h" |
46 | #include "tensorflow/core/graph/node_builder.h" |
47 | #include "tensorflow/core/graph/optimizer_cse.h" |
48 | #include "tensorflow/core/lib/core/threadpool.h" |
49 | #include "tensorflow/core/lib/gtl/map_util.h" |
50 | #include "tensorflow/core/platform/macros.h" |
51 | #include "tensorflow/core/platform/str_util.h" |
52 | #include "tensorflow/core/profiler/lib/connected_traceme.h" |
53 | #include "tensorflow/core/profiler/lib/traceme.h" |
54 | #include "tensorflow/core/protobuf/config.pb.h" |
55 | |
56 | // See core/kernels/function_ops.cc for related kernels. |
57 | |
58 | namespace tensorflow { |
59 | |
// A few string constants used throughout this module. They alias the
// canonical op names declared on FunctionLibraryDefinition so call sites in
// this file stay short.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix for graph nodes created by the Add* helpers below.
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
72 | |
73 | // Represents the index-th output of a node. |
74 | struct Endpoint { |
75 | Node* node; |
76 | int index; |
77 | |
78 | // Returns the string name represents this endpoint. |
79 | string name() const { |
80 | if (index == 0) { |
81 | return node->name(); |
82 | } else { |
83 | return strings::StrCat(node->name(), ":" , index); |
84 | } |
85 | } |
86 | |
87 | DataType dtype() const { return node->output_type(index); } |
88 | }; |
89 | |
90 | struct EndpointHash { |
91 | uint64 operator()(const Endpoint& x) const { |
92 | return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*), |
93 | x.index); |
94 | } |
95 | }; |
96 | |
97 | struct EndpointEq { |
98 | bool operator()(const Endpoint& x, const Endpoint& y) const { |
99 | return (x.node == y.node) && (x.index == y.index); |
100 | } |
101 | }; |
102 | |
103 | // The following Add* routines are used to add a few graph nodes while |
104 | // functions are transformed. |
105 | static Node* AddArg(Graph* g, DataType dtype, int index) { |
106 | DCHECK_LT(0, dtype); |
107 | DCHECK_LT(dtype, DT_FLOAT_REF); |
108 | NodeDef ndef; |
109 | ndef.set_name(g->NewName(kNodeLabel)); |
110 | ndef.set_op(kArgOp); |
111 | AddNodeAttr("T" , dtype, &ndef); |
112 | AddNodeAttr("index" , index, &ndef); |
113 | Status s; |
114 | Node* ret = g->AddNode(ndef, &s); |
115 | TF_CHECK_OK(s); |
116 | return ret; |
117 | } |
118 | |
119 | static Node* AddRet(Graph* g, Endpoint input, int index) { |
120 | DCHECK_LT(0, input.dtype()); |
121 | DCHECK_LT(input.dtype(), DT_FLOAT_REF); |
122 | NodeDef ndef; |
123 | ndef.set_name(g->NewName(kNodeLabel)); |
124 | ndef.set_op(kRetOp); |
125 | ndef.add_input(input.name()); |
126 | AddNodeAttr("T" , input.dtype(), &ndef); |
127 | AddNodeAttr("index" , index, &ndef); |
128 | Status s; |
129 | Node* ret = g->AddNode(ndef, &s); |
130 | TF_CHECK_OK(s); |
131 | g->AddEdge(input.node, input.index, ret, 0); |
132 | return ret; |
133 | } |
134 | |
// FunctionLibraryRuntime implementation that forwards all the function calls
// to the base runtime implementation, and only overrides
// FunctionLibraryDefinition in calls to Instantiate (if the caller doesn't
// provide the InstantiateOptions::lib_def option).
//
// When the function library runtime (FunctionLibraryRuntimeImpl specifically)
// instantiates a function into a Graph object, it also creates an Executor for
// it. That executor has a pointer to the function library runtime instance,
// that is used to instantiate all nested function calls.
//
// The function library definition used to instantiate the function must be
// preserved in the Executor's function library runtime.
//
// IMPORTANT: This runtime is intended for use only in executors created for
// functions instantiated into a graph in FunctionLibraryRuntimeImpl.
class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeOverlay(FunctionLibraryRuntime* base_flr,
                                const FunctionLibraryDefinition* lib_def)
      : base_flr_(base_flr), lib_def_(lib_def) {}
  ~FunctionLibraryRuntimeOverlay() override;

  // Forwards to base_flr_, injecting lib_def_ into the options when the
  // caller did not set options.lib_def explicitly.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle h) override;

  Status GetRetTypes(Handle h, DataTypeVector* ret_types) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;

  void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
           DoneCallback done) override;

  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;

  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* frame) override;

  // Always fails: kernel creation through an overlay is disabled (see the
  // definition for the rationale).
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  // Answered from lib_def_ only; NOT forwarded to the base runtime.
  bool IsStateful(const string& function_name) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;

  Env* env() override;
  const ConfigProto* const config_proto() override;
  Device* device() override;
  const Device* device() const override;
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

  string DebugString(Handle handle) override;
  int graph_def_version() const override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  FunctionLibraryRuntime* base_flr_;          // not owned
  const FunctionLibraryDefinition* lib_def_;  // not owned
};
206 | |
// Nothing to release: base_flr_ and lib_def_ are not owned by the overlay.
FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;
208 | |
209 | Status FunctionLibraryRuntimeOverlay::Instantiate( |
210 | const string& function_name, AttrSlice attrs, |
211 | const InstantiateOptions& options, Handle* handle) { |
212 | // We automatically set the `lib_def` option for all instantiations, if the |
213 | // caller doesn't set this option explicitly. |
214 | if (!options.lib_def && lib_def_) { |
215 | InstantiateOptions options_copy = options; |
216 | options_copy.lib_def = lib_def_; |
217 | return base_flr_->Instantiate(function_name, attrs, options_copy, handle); |
218 | } else { |
219 | return base_flr_->Instantiate(function_name, attrs, options, handle); |
220 | } |
221 | } |
222 | |
// Forwards handle release to the base runtime, which owns all instantiations.
Status FunctionLibraryRuntimeOverlay::ReleaseHandle(Handle handle) {
  return base_flr_->ReleaseHandle(handle);
}
226 | |
// Forwards body lookup to the base runtime, which owns the FunctionBody.
const FunctionBody* FunctionLibraryRuntimeOverlay::GetFunctionBody(Handle h) {
  return base_flr_->GetFunctionBody(h);
}
230 | |
// Forwards return-type lookup to the base runtime.
Status FunctionLibraryRuntimeOverlay::GetRetTypes(Handle h,
                                                  DataTypeVector* ret_types) {
  return base_flr_->GetRetTypes(h, ret_types);
}
235 | |
// Forwards asynchronous execution (tensor-vector form) to the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        gtl::ArraySlice<Tensor> args,
                                        std::vector<Tensor>* rets,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, args, rets, std::move(done));
}
242 | |
// Forwards asynchronous execution (call-frame form) to the base runtime.
void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
                                        CallFrameInterface* call_frame,
                                        DoneCallback done) {
  base_flr_->Run(opts, handle, call_frame, std::move(done));
}
248 | |
// Forwards synchronous execution (tensor-vector form) to the base runtime.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              gtl::ArraySlice<Tensor> args,
                                              std::vector<Tensor>* rets) {
  return base_flr_->RunSync(std::move(opts), handle, args, rets);
}
254 | |
// Forwards synchronous execution (call-frame form) to the base runtime.
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
                                              CallFrameInterface* call_frame) {
  return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
259 | |
// Always returns an Internal error: kernel creation via overlays is disabled.
Status FunctionLibraryRuntimeOverlay::CreateKernel(
    const std::shared_ptr<const NodeProperties>&, OpKernel**) {
  // We don't have access to base_lib_def_ in base function library runtime
  // (aka FunctionLibraryRuntimeImpl), so to make sure we do not create a
  // kernel with the wrong lib_def we just disable creation of new kernels
  // through overlays.
  //
  // When we call Instantiate from the base runtime with the lib_def option,
  // the base runtime implementation is responsible for correctly passing it
  // through to all kernel constructions.
  return errors::Internal(
      "Overlay function library runtime doesn't support kernel creation.");
}
272 | |
273 | bool FunctionLibraryRuntimeOverlay::IsStateful( |
274 | const string& function_name) const { |
275 | // Important: we do not forward lookup to the base FLR. |
276 | const OpDef* op_def; |
277 | const Status s = lib_def_->LookUpOpDef(function_name, &op_def); |
278 | return s.ok() && op_def->is_stateful(); |
279 | } |
280 | |
// The overlay has no Env of its own; forwards to the base runtime.
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
282 | |
// Forwards to the base runtime's session configuration.
const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
  return base_flr_->config_proto();
}
286 | |
// Forwards to the base runtime's device.
Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
288 | |
// Const overload; forwards to the base runtime's device.
const Device* FunctionLibraryRuntimeOverlay::device() const {
  return base_flr_->device();
}
292 | |
// Forwards to the base runtime's default closure runner.
std::function<void(std::function<void()>)>*
FunctionLibraryRuntimeOverlay::runner() {
  return base_flr_->runner();
}
297 | |
// Forwards to the base runtime's device manager.
const DeviceMgr* FunctionLibraryRuntimeOverlay::device_mgr() const {
  return base_flr_->device_mgr();
}
301 | |
302 | const FunctionLibraryDefinition* |
303 | FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const { |
304 | return lib_def_ ? lib_def_ : base_flr_->GetFunctionLibraryDefinition(); |
305 | } |
306 | |
// Forwards the debug-string request to the base runtime.
string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) {
  return base_flr_->DebugString(handle);
}
310 | |
// Forwards to the base runtime's GraphDef version.
int FunctionLibraryRuntimeOverlay::graph_def_version() const {
  return base_flr_->graph_def_version();
}
314 | |
// Clones the base runtime. The FunctionLibraryDefinition override is
// intentionally dropped from the clone.
Status FunctionLibraryRuntimeOverlay::Clone(
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
    FunctionLibraryRuntime** out_flr, bool skip_flib_def) {
  // NOTE(ezhulenev): The cloned FunctionLibraryRuntime will be missing the
  // FunctionLibraryDefinition override, but that's ok because we anyway do not
  // copy / clone instantiated items from the base FLR.
  return base_flr_->Clone(out_lib_def, out_pflr, out_flr, skip_flib_def);
}
324 | |
// The concrete, per-device FunctionLibraryRuntime. Instantiates functions
// into Graph objects, creates executors for them, and runs them.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
                             const ConfigProto* config, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             thread::ThreadPool* default_thread_pool,
                             const OptimizerOptions& optimizer_options,
                             const SessionMetadata* session_metadata,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  Status ReleaseHandle(Handle handle) override;

  const FunctionBody* GetFunctionBody(Handle handle) override;

  Status GetRetTypes(Handle handle, DataTypeVector* ret_types) override;

  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;
  Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
                 std::vector<Tensor>* rets) override;
  Status RunSync(Options opts, Handle handle,
                 CallFrameInterface* call_frame) override;

  bool IsStateful(const string& function) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  const Device* device() const override { return device_; }

  std::function<void(std::function<void()>)>* runner() override {
    return &default_runner_;
  }

  const DeviceMgr* device_mgr() const override { return device_mgr_; }
  Env* env() override { return env_; }
  const ConfigProto* const config_proto() override { return config_; }
  int graph_def_version() const override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr,
               bool skip_flib_def = false) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  // None of the raw pointers below are owned by this class.
  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const ConfigProto* const config_;
  const int graph_def_version_;
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  const SessionMetadata* const session_metadata_;
  // Default closure runner; set from a thread pool in the constructor and
  // may remain null if no pool is available.
  Executor::Args::Runner default_runner_;
  const string device_name_;

  // Looks up an op signature in base_lib_def_ (set in the constructor).
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  // Creates a kernel via this runtime (set in the constructor).
  std::function<Status(const std::shared_ptr<const NodeProperties>&,
                       OpKernel**)>
      create_kernel_;

  // Guards next_handle_ and items_.
  mutable mutex mu_;

  int next_handle_ TF_GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item {
    uint64 instantiation_counter = 0;
    std::unique_ptr<const Graph> graph = nullptr;
    const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;
    FunctionLibraryRuntimeOverlay* overlay_flr = nullptr;
    string executor_type;
    bool allow_small_function_optimizations = false;
    bool allow_control_flow_sync_execution = false;

    // func_graph, exec, and overlay_flr are owned by the Item.
    ~Item() {
      delete this->func_graph;
      delete this->exec;
      delete this->overlay_flr;
    }
  };
  // Maps local handles to their Items. Held behind a unique_ptr so the
  // destructor can reset it explicitly (see ~FunctionLibraryRuntimeImpl).
  std::unique_ptr<absl::flat_hash_map<Handle, std::unique_ptr<Item>>> items_
      TF_GUARDED_BY(mu_);
  std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  // Overloads the CreateKernel method, providing a FunctionLibraryRuntime
  // to use for kernel creation and execution. In particular, this method can
  // accept a FunctionLibraryRuntimeOverlay that overlays a different
  // FunctionLibraryDefinition.
  Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                      FunctionLibraryRuntime* flr, OpKernel** kernel);
  // Converts `fdef` into a FunctionBody, resolving op signatures against
  // `lib_def`.
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           std::unique_ptr<FunctionBody>* fbody);
  Status CreateItem(Item** item);
  Status GetOrCreateItem(LocalHandle local_handle, Item** item);
  // Builds the gradient function body for `func`, from a registered gradient
  // creator (primitive ops) or from SymbolicGradient (user functions).
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     std::unique_ptr<FunctionBody>* g_body);
  // Returns true iff `options` targets this runtime's own device.
  bool IsLocalTarget(const InstantiateOptions& options) const;
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Item* item, DoneCallback done);

  // TODO(fishx): Avoid using std::unique_ptr for PrivateIntraProcessRendezvous,
  // since it will allocate the object on heap.
  Status PrepareRunSync(
      Handle handle, Options* run_opts, Item** out_item,
      std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);

  void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
                               CallFrameInterface* frame,
                               Executor::Args* exec_args);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
465 | |
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    thread::ThreadPool* default_thread_pool,
    const OptimizerOptions& optimizer_options,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      config_(config),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      session_metadata_(session_metadata),
      default_runner_(nullptr),
      // A null device means this is the "default" FLR (see
      // ProcessFunctionLibraryRuntime::kDefaultFLRDevice).
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      items_(std::make_unique<
             absl::flat_hash_map<Handle, std::unique_ptr<Item>>>()),
      function_handle_cache_(std::make_unique<FunctionHandleCache>(this)),
      parent_(parent) {
  // Default signature lookup and kernel creation both go through the base
  // function library.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  create_kernel_ = [this](const std::shared_ptr<const NodeProperties>& props,
                          OpKernel** kernel) {
    return CreateKernel(props, kernel);
  };
  // Choose the runner's thread pool: prefer the device's own pool, then the
  // supplied default pool. If neither exists, default_runner_ stays null.
  thread::ThreadPool* pool = nullptr;
  if (device_ != nullptr) {
    pool = device_->tensorflow_device_thread_pool();
  }
  if (pool == nullptr) {
    pool = default_thread_pool;
  }
  if (pool != nullptr) {
    default_runner_ = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }
}
510 | |
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
  // Deleting the items_ list will delete all the function handles registered
  // in this object. A function may contain a few sub-functions which have
  // also been registered in this object. Deleting the parent function will
  // call ReleaseHandle in this class again for each of the sub-functions.
  // These circular calls may cause segfault since the items_ may have already
  // been partially deleted when releasing handles of sub-functions.
  // Explicitly release items_ here and check it in ReleaseHandle to avoid
  // this.
  items_.reset();
}
521 | |
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {
    // TODO(iga): Release the cached handle_
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the kernel context's execution state into the function call.
    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.collective_executor = ctx->collective_executor();
    opts.stack_trace = ctx->stack_trace();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated so it outlives this frame; deleted inside the callback
    // once the asynchronous Run completes.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  // Handle of the instantiated function this kernel invokes.
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
574 | |
// Returns the FunctionBody for handle `h`, or nullptr (with an error log) if
// the handle is not registered on this runtime's device.
const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    LOG(ERROR) << "Could not find Handle: " << h
               << " on device: " << device_name_;
    return nullptr;
  }

  // A valid local handle is expected to always have an Item (CHECK below).
  tf_shared_lock l(mu_);
  auto iter = items_->find(local_handle);
  CHECK(iter != items_->end());
  return iter->second->func_graph;
}
588 | |
// Copies the return types of the function identified by `h` into
// `*ret_types`. Multi-device functions are delegated to the parent runtime.
Status FunctionLibraryRuntimeImpl::GetRetTypes(Handle h,
                                               DataTypeVector* ret_types) {
  if (parent_->IsMultiDevice(h)) {
    return parent_->GetRetTypes(h, ret_types);
  }
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
  if (local_handle == kInvalidLocalHandle) {
    return errors::InvalidArgument("Handle ", h, " not found.");
  }
  const FunctionBody* fbody = GetFunctionBody(h);
  *ret_types = fbody->ret_types;
  return OkStatus();
}
602 | |
// Public CreateKernel entry point: uses this runtime itself as the FLR for
// kernel construction (see the three-argument overload below).
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props, OpKernel** kernel) {
  return CreateKernel(props, this, kernel);
}
607 | |
// Creates a kernel for the op described by `props`, using `flr` (which may be
// an overlay) for library lookups and nested instantiation. Primitive ops get
// their registered kernel; functions get a CallOp wrapping an instantiation.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const std::shared_ptr<const NodeProperties>& props,
    FunctionLibraryRuntime* flr, OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  const CustomKernelCreator* custom_kernel_creator =
      GetDefaultCustomKernelCreator();
  if (custom_kernel_creator &&
      custom_kernel_creator->CanCreateKernel(*flr, props)) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator->CreateKernel(flr, props, &ret);
    if (s.ok()) {
      *kernel = ret.release();
    } else {
      VLOG(2) << "Custom creator error: " << s;
    }
    return s;
  }

  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  if (lib_def->Find(props->node_def.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, flr, props, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.lib_def = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(props->node_def.op(),
                                 AttrSlice(&props->node_def.attr()), options,
                                 &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(MTypeFromDType(t));
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(MTypeFromDType(t));
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  // NOTE(review): new_props is built but `props` is what is passed to
  // OpKernelConstruction below — presumably intentional; confirm.
  auto new_props = std::make_shared<NodeProperties>(
      &fbody->fdef.signature(), props->node_def, fbody->arg_types,
      fbody->ret_types);
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), flr,
      device_->resource_manager(), props, input_memory_types,
      output_memory_types, graph_def_version_, &s);
  if (s.ok()) {
    *kernel = new CallOp(handle, &construction);
  }
  return s;
}
676 | |
677 | Status FunctionLibraryRuntimeImpl::FunctionDefToBody( |
678 | const FunctionDef& fdef, AttrSlice attrs, |
679 | const FunctionLibraryDefinition* lib_def, |
680 | std::unique_ptr<FunctionBody>* fbody) { |
681 | if (lib_def == base_lib_def_) { |
682 | return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody); |
683 | } else { |
684 | auto get_func_sig = [lib_def](const string& op, const OpDef** sig) { |
685 | return lib_def->LookUpOpDef(op, sig); |
686 | }; |
687 | return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); |
688 | } |
689 | } |
690 | |
// Builds the body of the gradient function for `func`. For a primitive op a
// registered gradient creator produces the FunctionDef; for a user-defined
// function the function is instantiated and SymbolicGradient is applied.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    std::unique_ptr<FunctionBody>* g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function. Instantiate f, then differentiate its
    // body symbolically.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.lib_def = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return OkStatus();
}
724 | |
725 | bool FunctionLibraryRuntimeImpl::IsLocalTarget( |
726 | const InstantiateOptions& options) const { |
727 | if (device_ == nullptr) return true; |
728 | if (options.target.empty()) return true; |
729 | if (options.is_multi_device_function) return false; |
730 | Device* target_device; |
731 | if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) { |
732 | VLOG(1) << "Not instantiating function in FLR because failed to " |
733 | << "find device " << options.target << " in device manager" ; |
734 | return false; |
735 | } |
736 | if (target_device != device_) { |
737 | VLOG(1) << "Not instantiating function in FLR because target device " |
738 | << options.target |
739 | << " is different from FLR's device: " << device_->DebugString(); |
740 | return false; |
741 | } |
742 | return true; |
743 | } |
744 | |
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // Non-local targets are instantiated by the parent
  // ProcessFunctionLibraryRuntime.
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Serve from the per-FLR function handle cache. The flag is cleared on the
  // copy so the instantiation performed on behalf of the cache takes the
  // uncached path below instead of looping back here.
  if (options.use_function_cache) {
    InstantiateOptions options_copy(options);
    options_copy.use_function_cache = false;
    return function_handle_cache_->Instantiate(function_name, attrs,
                                               options_copy, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);

  {
    // Fast path: the function was already instantiated under this key, so we
    // only bump its instantiation count.
    mutex_lock l(mu_);
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      FunctionLibraryRuntime::LocalHandle handle_on_device =
          parent_->GetHandleOnDevice(device_name_, *handle);
      if (handle_on_device == kInvalidLocalHandle) {
        return errors::Internal("LocalHandle not found for handle " , *handle,
                                "." );
      }
      auto item_handle = items_->find(handle_on_device);
      if (item_handle == items_->end()) {
        return errors::Internal("LocalHandle " , handle_on_device,
                                " for handle " , *handle,
                                " not found in items." );
      }
      ++item_handle->second->instantiation_counter;
      return OkStatus();
    }
  }

  // Build the function body outside the lock. An overlay library supplied in
  // the options takes precedence over this runtime's base library.
  const FunctionLibraryDefinition* lib_def =
      options.lib_def ? options.lib_def : base_lib_def_;
  std::unique_ptr<FunctionBody> fbody;
  if (function_name == kGradientOp) {
    // SymbolicGradient(f): prefer a gradient registered for `f` in the
    // library; otherwise derive one symbolically from f's body.
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f" );
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient" );
    }
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function " , function_name, " is not defined." );
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  LocalHandle local_handle;
  {
    mutex_lock l(mu_);
    // Re-check under the lock: another thread may have instantiated the same
    // key while we were building `fbody` above. In that case our fbody is
    // discarded and the existing item's count is bumped instead.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      local_handle = parent_->GetHandleOnDevice(device_name_, *handle);
      ++(*items_)[local_handle]->instantiation_counter;
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody.release();
      item->instantiation_counter = 1;
      item->executor_type = ExecutorType(options, attrs);
      item->allow_small_function_optimizations =
          options.allow_small_function_optimizations;
      item->allow_control_flow_sync_execution =
          options.allow_control_flow_sync_execution;
      if (options.lib_def) {
        // Remember the overlay so later kernel creation resolves function
        // names against it rather than the base library.
        item->overlay_flr =
            new FunctionLibraryRuntimeOverlay(this, options.lib_def);
      }
      local_handle = next_handle_++;
      items_->emplace(local_handle, std::unique_ptr<Item>(item));
    }
  }

  // Optionally build kernels/executor now instead of lazily on first Run().
  if (options.create_kernels_eagerly) {
    Item* item;
    TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, &item));
  }

  return OkStatus();
}
844 | |
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  // Handles that are not local to this device are released by the parent.
  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  if (h == kInvalidLocalHandle) {
    return parent_->ReleaseHandle(handle);
  }
  // Declared before the lock scope so the Item is destroyed only after mu_
  // has been released (see the comment inside).
  std::unique_ptr<Item> item_to_delete;
  Status parent_status;
  {
    mutex_lock l(mu_);
    // Return directly if all items has already been released.
    if (items_ == nullptr) return OkStatus();

    auto it = items_->find(h);
    if (it == items_->end()) {
      return errors::Internal(
          "Inconsistent FunctionLibraryRuntime. Expected to find an item for "
          "handle " ,
          h, " but found none" );
    }
    std::unique_ptr<Item>& item = it->second;
    --item->instantiation_counter;
    if (item->instantiation_counter == 0) {
      // We don't simply erase h's item because that would trigger
      // item destruction while holding mu_. Item destruction can
      // trigger graph destruction. If the graph contains kernels like
      // CallOp or PartitionCallOp, their destructors will release cached
      // function handles, resulting in deadlock here.
      item_to_delete = std::move(item);
      items_->erase(h);
      parent_status = parent_->RemoveHandle(handle);
    }
  }
  // `item_to_delete` (if set) is destroyed here, outside the lock.
  return parent_status;
}
879 | |
880 | namespace { |
881 | |
882 | // Removes all stateless nodes that do not contribute to a return |
883 | // value from the function body. Unlike `RemoveDeadNodes()`, which is |
884 | // triggered by `OptimizerOptions.do_function_inlining`, this pass |
885 | // ignores the SINK node, from which (by definition) all nodes are |
886 | // reverse reachable, and preserves all nodes that are reachable from |
887 | // control output nodes. |
888 | // |
889 | // TODO(ezhulenev, skyewm): Function body should not have special treatment of |
890 | // stateful ops, graph should encode nodes that must execute with `control_ret` |
891 | // and `control_output`. |
892 | void PruneFunctionBody(const FunctionDef& fdef, Graph* g) { |
893 | VLOG(2) << "Pruning function body: function_name=" << fdef.signature().name(); |
894 | |
895 | // `control_ret` nodes must be always executed. |
896 | std::unordered_set<StringPiece, StringPieceHasher> control_ret_nodes; |
897 | for (const auto& control_ret : fdef.control_ret()) { |
898 | control_ret_nodes.insert(control_ret.second); |
899 | } |
900 | |
901 | std::unordered_set<const Node*> nodes; |
902 | for (auto n : g->nodes()) { |
903 | // NOTE(mrry): "_Retval" nodes are stateful, and so will be added |
904 | // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we |
905 | // specifically exclude them as seeds, to avoid unconditionally executing |
906 | // unused argument nodes (e.g. in a function like `lambda x, y: y`). |
907 | // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is |
908 | // still needed. It would be preferable to prune entire loops and/or |
909 | // conditionals if they are not used in the graph. |
910 | if (n->IsControlFlow() || |
911 | (n->op_def().is_stateful() && n->type_string() != kArgOp) || |
912 | (control_ret_nodes.find(n->name()) != control_ret_nodes.end())) { |
913 | nodes.insert(n); |
914 | } |
915 | } |
916 | bool changed = PruneForReverseReachability(g, std::move(nodes)); |
917 | if (changed) { |
918 | FixupSourceAndSinkEdges(g); |
919 | } |
920 | } |
921 | |
922 | } // namespace |
923 | |
// Builds the optimized graph and executor for an already-instantiated item.
// Reads the item's fields under a shared lock, then performs the expensive
// optimization and executor construction without holding the lock.
Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
  const FunctionBody* fbody;
  FunctionLibraryRuntime* flr;
  string executor_type;
  {
    tf_shared_lock l(mu_);
    fbody = (*item)->func_graph;
    // Use the overlay runtime (if any) so name resolution during kernel
    // creation sees the overlay function library.
    flr = (*item)->overlay_flr
              ? static_cast<FunctionLibraryRuntime*>((*item)->overlay_flr)
              : static_cast<FunctionLibraryRuntime*>(this);
    executor_type = (*item)->executor_type;
  }
  const FunctionLibraryDefinition* lib_def =
      flr->GetFunctionLibraryDefinition();
  // Work on a copy of the function body graph; the original stays untouched.
  auto g = std::make_unique<Graph>(lib_def);
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(fbody->fdef, g.get());
  optimizer_.Optimize(this, env(), device(), &g, GraphOptimizer::Options());
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g. This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = flr;
  params.allow_control_flow_sync_execution =
      (*item)->allow_control_flow_sync_execution;
  if (flr == this) {
    params.create_kernel = create_kernel_;
  } else {
    // Overlay case: route kernel creation through this runtime but pass the
    // overlay FLR, so function-calling kernels use the overlay library.
    params.create_kernel =
        [this, flr](const std::shared_ptr<const NodeProperties>& props,
                    OpKernel** kernel) {
          return CreateKernel(props, flr, kernel);
        };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  params.session_metadata = session_metadata_;
  std::unique_ptr<Executor> exec;

  // When the instantiation options request small function optimizations, all
  // graphs which are safe for synchronous execution will set this flag to true:
  if ((*item)->allow_small_function_optimizations && executor_type.empty()) {
    executor_type = "SINGLE_THREADED_EXECUTOR" ;
  }

  metrics::IncrementTestCounter("flr_executor" ,
                                (executor_type == "SINGLE_THREADED_EXECUTOR" )
                                    ? "single_threaded"
                                    : "default" );

  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    // A concurrent CreateItem may have won the race; keep the first executor
    // and silently drop ours.
    if ((*item)->exec == nullptr) {
      (*item)->graph = std::move(g);
      (*item)->exec = exec.release();
    }
  }
  return OkStatus();
}
990 | |
991 | Status FunctionLibraryRuntimeImpl::GetOrCreateItem(LocalHandle local_handle, |
992 | Item** item) { |
993 | { |
994 | tf_shared_lock l(mu_); |
995 | auto iter = items_->find(local_handle); |
996 | if (iter == items_->end()) { |
997 | return errors::Internal("Local function handle " , local_handle, |
998 | " is not valid. Likely an internal error." ); |
999 | } |
1000 | *item = iter->second.get(); |
1001 | if ((*item)->exec != nullptr) { |
1002 | return OkStatus(); |
1003 | } |
1004 | } |
1005 | // NOTE: We need to call CreateItem out of mu_ because creating an |
1006 | // executor needs to call CreateKernel. |
1007 | return CreateItem(item); |
1008 | } |
1009 | |
1010 | void FunctionLibraryRuntimeImpl::ExecutorArgsFromOptions( |
1011 | const FunctionLibraryRuntime::Options& run_opts, CallFrameInterface* frame, |
1012 | Executor::Args* exec_args) { |
1013 | // Inherit the step_id from the caller. |
1014 | exec_args->step_id = run_opts.step_id; |
1015 | exec_args->rendezvous = run_opts.rendezvous; |
1016 | exec_args->stats_collector = run_opts.stats_collector; |
1017 | exec_args->cancellation_manager = run_opts.cancellation_manager; |
1018 | exec_args->step_container = run_opts.step_container; |
1019 | if (run_opts.runner) { |
1020 | exec_args->runner = *run_opts.runner; |
1021 | } else { |
1022 | exec_args->runner = default_runner_; |
1023 | } |
1024 | exec_args->collective_executor = run_opts.collective_executor; |
1025 | exec_args->call_frame = frame; |
1026 | exec_args->run_all_kernels_inline = run_opts.run_all_kernels_inline; |
1027 | exec_args->user_intra_op_threadpool = run_opts.user_intra_op_threadpool; |
1028 | exec_args->coordination_service_agent = run_opts.coordination_service_agent; |
1029 | exec_args->stack_trace = run_opts.stack_trace; |
1030 | } |
1031 | |
// Runs a function whose caller lives on a different device: receives the
// arguments through the rendezvous, executes locally, then sends the return
// values back. The heap objects allocated below (frame, exec_args,
// remote_args) are manually deleted by the callbacks once the async pipeline
// is done with them.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Item* item, DoneCallback done) {
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  RendezvousInterface* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    done(s);
    return;
  }
  // Incarnations distinguish device restarts; needed for the rendezvous keys.
  int64_t src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Owned by the callbacks below.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  Executor::Args* exec_args = new Executor::Args;
  ExecutorArgsFromOptions(opts, frame, exec_args);

  // Compute allocator attributes for each argument and return value.
  std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
  args_alloc_attrs.reserve(fbody->arg_types.size());
  rets_alloc_attrs.reserve(fbody->ret_types.size());
  // Note: Functions assume that int32's are always on host memory.
  for (const auto& arg_type : fbody->arg_types) {
    AllocatorAttributes arg_alloc_attrs;
    if (MTypeFromDType(arg_type) == HOST_MEMORY) {
      arg_alloc_attrs.set_on_host(true);
    }
    args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  for (const auto& ret_type : fbody->ret_types) {
    AllocatorAttributes ret_alloc_attrs;
    if (MTypeFromDType(ret_type) == HOST_MEMORY) {
      ret_alloc_attrs.set_on_host(true);
    }
    rets_alloc_attrs.push_back(ret_alloc_attrs);
  }

  bool allow_dead_tensors = opts.allow_dead_tensors;

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_" , src_incarnation, args.size(),
      device_context, args_alloc_attrs, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done, exec_args,
       rets_alloc_attrs, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          // Abort before execution: release everything the rest of the
          // pipeline would have consumed.
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        item->exec->RunAsync(
            *exec_args,
            [frame, rets, done, source_device, target_device,
             target_incarnation, rendezvous, device_context, remote_args,
             rets_alloc_attrs, allow_dead_tensors](const Status& status) {
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets, allow_dead_tensors);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                done(s);
                return;
              }
              // Ship the results back to the caller via the rendezvous.
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_" , target_incarnation,
                  *rets, device_context, rets_alloc_attrs, rendezvous);
              delete remote_args;
              done(s);
            });
        // NOTE(review): exec_args is deleted as soon as RunAsync returns,
        // which presumes the executor copies what it needs from Args up
        // front — confirm against Executor::RunAsync's contract.
        delete exec_args;
      });
}
1126 | |
// Asynchronously runs the instantiated function `handle` with `args`,
// storing outputs in `*rets` and invoking `done` with the final status.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  // Fail fast if the caller's cancellation manager has already fired.
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("Function was cancelled before it was started" ));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous; it is unref'd when `done` runs.
    auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      rendezvous->Unref();
      done(status);
    };
  }

  // Non-local handles are executed by the parent ProcessFLR.
  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, item, std::move(done));
    return;
  }

  // Local execution: marshal args through a heap-allocated call frame that
  // the completion callback below owns and deletes.
  const FunctionBody* fbody = GetFunctionBody(handle);
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    done(s);
    return;
  }

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun" ,
                                       {{"id" , opts.step_id}, {"_r" , 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);

  bool allow_dead_tensors = run_opts.allow_dead_tensors;
  item->exec->RunAsync(
      // Executor args
      exec_args,
      // Done callback.
      [frame, rets, done, allow_dead_tensors](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = frame->ConsumeRetvals(rets, allow_dead_tensors);
        }
        delete frame;
        done(s);
      });
}
1206 | |
// Variant of Run() that takes a caller-provided call frame instead of
// explicit args/rets vectors.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  // Fail fast if the caller's cancellation manager has already fired.
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled("" ));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Create a per-call rendezvous; it is unref'd when `done` runs.
    auto* rendezvous = new RefCountedIntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done = std::move(done), rendezvous](const Status& status) mutable {
      rendezvous->Unref();
      done(status);
    };
  }

  // Non-local handles (including multi-device) go through the parent.
  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    parent_->Run(run_opts, handle, frame, done);
    return;
  }

  if (opts.remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the vector-based
    // `args`/`rets` interface.
    done(errors::Unimplemented("Remote calling with CallFrameInterface" ));
    return;
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(local_handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (run_opts.runner == nullptr) {
    run_opts.runner = &default_runner_;
  }
  DCHECK(run_opts.runner != nullptr);

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&opts] {
        return profiler::TraceMeEncode("FunctionRun" ,
                                       {{"id" , opts.step_id}, {"_r" , 1}});
      },
      profiler::ContextType::kTfExecutor, opts.step_id,
      profiler::TraceMeLevel::kInfo);

  Executor::Args exec_args;
  ExecutorArgsFromOptions(run_opts, frame, &exec_args);
  item->exec->RunAsync(exec_args, std::move(done));
}
1266 | |
// Shared setup for both RunSync() overloads. On success, `*out_item` is the
// local item to execute, or nullptr when the handle is not local (the caller
// should delegate to the parent). `*run_opts` may be rewritten to point at a
// freshly created rendezvous, whose ownership is returned via
// `*out_rendezvous`.
Status FunctionLibraryRuntimeImpl::PrepareRunSync(
    Handle handle, Options* run_opts, Item** out_item,
    std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
  if (run_opts->cancellation_manager &&
      run_opts->cancellation_manager->IsCancelled()) {
    return errors::Cancelled("" );
  }

  if (run_opts->remote_execution) {
    // NOTE(mrry): This bit is only set for a local function when `parent_`
    // calls back into this class, and the current implementation of
    // `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
    // Run() method.
    return errors::Unimplemented("Remote calling with RunSync()" );
  }

  if (run_opts->create_rendezvous) {
    // The rendezvous must outlive the run; the caller keeps it alive through
    // *out_rendezvous.
    *out_rendezvous =
        std::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
    run_opts->rendezvous = out_rendezvous->get();
    run_opts->create_rendezvous = false;
  }

  LocalHandle local_handle = parent_->GetHandleOnDevice(
      device_name_, handle, /*include_multi_device=*/true);
  if (local_handle == kInvalidLocalHandle) {
    *out_item = nullptr;
    return OkStatus();
  }

  TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));

  if (run_opts->runner == nullptr) {
    run_opts->runner = &default_runner_;
  }
  DCHECK(run_opts->runner != nullptr);

  return OkStatus();
}
1306 | |
1307 | Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, |
1308 | gtl::ArraySlice<Tensor> args, |
1309 | std::vector<Tensor>* rets) { |
1310 | Item* item = nullptr; |
1311 | std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous; |
1312 | TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous)); |
1313 | if (item == nullptr) { |
1314 | return parent_->RunSync(opts, handle, args, rets); |
1315 | } |
1316 | |
1317 | Executor::Args exec_args; |
1318 | const FunctionBody* fbody = GetFunctionBody(handle); |
1319 | FunctionCallFrame frame(fbody->arg_types, fbody->ret_types); |
1320 | TF_RETURN_IF_ERROR(frame.SetArgs(args)); |
1321 | ExecutorArgsFromOptions(opts, &frame, &exec_args); |
1322 | |
1323 | TF_RETURN_IF_ERROR(item->exec->Run(exec_args)); |
1324 | return frame.ConsumeRetvals(rets, opts.allow_dead_tensors); |
1325 | } |
1326 | |
1327 | Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, |
1328 | CallFrameInterface* call_frame) { |
1329 | Item* item = nullptr; |
1330 | std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous; |
1331 | TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous)); |
1332 | if (item == nullptr) { |
1333 | return parent_->RunSync(opts, handle, call_frame); |
1334 | } |
1335 | |
1336 | Executor::Args exec_args; |
1337 | ExecutorArgsFromOptions(opts, call_frame, &exec_args); |
1338 | return item->exec->Run(exec_args); |
1339 | } |
1340 | |
1341 | bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const { |
1342 | const OpDef* op_def; |
1343 | const Status s = base_lib_def_->LookUpOpDef(func, &op_def); |
1344 | return s.ok() && op_def->is_stateful(); |
1345 | } |
1346 | |
1347 | string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { |
1348 | Item* item = nullptr; |
1349 | LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); |
1350 | Status s = GetOrCreateItem(local_handle, &item); |
1351 | if (s.ok()) { |
1352 | if (item->graph) { |
1353 | return tensorflow::DebugString(item->graph.get()); |
1354 | } else { |
1355 | return tensorflow::DebugString(item->func_graph->graph); |
1356 | } |
1357 | } else { |
1358 | return s.ToString(); |
1359 | } |
1360 | } |
1361 | |
1362 | Status FunctionLibraryRuntimeImpl::Clone( |
1363 | std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, |
1364 | std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, |
1365 | FunctionLibraryRuntime** out_flr, bool skip_flib_def) { |
1366 | TF_RETURN_IF_ERROR(parent_->Clone(env_, graph_def_version_, |
1367 | optimizer_.options(), out_lib_def, out_pflr, |
1368 | skip_flib_def)); |
1369 | *out_flr = (*out_pflr)->GetFLR(device_->name()); |
1370 | if (*out_flr != nullptr) { |
1371 | return OkStatus(); |
1372 | } else { |
1373 | return errors::Internal("Cloning FunctionLibraryRuntime failed." ); |
1374 | } |
1375 | } |
1376 | |
1377 | namespace { |
1378 | |
1379 | struct CustomCreatorSingleton { |
1380 | mutex mu; |
1381 | std::unique_ptr<CustomKernelCreator> custom_creator = nullptr; |
1382 | |
1383 | void Set(CustomKernelCreator* cb) { |
1384 | mutex_lock l(mu); |
1385 | custom_creator.reset(cb); |
1386 | } |
1387 | |
1388 | CustomKernelCreator* Get() { |
1389 | mutex_lock l(mu); |
1390 | return custom_creator.get(); |
1391 | } |
1392 | }; |
1393 | |
1394 | CustomCreatorSingleton* GetCustomCreatorSingleton() { |
1395 | static CustomCreatorSingleton* ccs = new CustomCreatorSingleton; |
1396 | return ccs; |
1397 | } |
1398 | |
1399 | } // namespace |
1400 | |
1401 | const CustomKernelCreator* GetDefaultCustomKernelCreator() { |
1402 | return GetCustomCreatorSingleton()->Get(); |
1403 | } |
1404 | |
1405 | void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) { |
1406 | GetCustomCreatorSingleton()->Set(c); |
1407 | } |
1408 | |
1409 | std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( |
1410 | const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, |
1411 | Device* device, int graph_def_version, |
1412 | const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool, |
1413 | const OptimizerOptions& optimizer_options, |
1414 | const SessionMetadata* session_metadata, |
1415 | ProcessFunctionLibraryRuntime* parent) { |
1416 | return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( |
1417 | device_mgr, env, config, device, graph_def_version, lib_def, thread_pool, |
1418 | optimizer_options, session_metadata, parent)); |
1419 | } |
1420 | |
// Builds the body of the symbolic gradient function for a given function
// body `f` (see SymbolicGradient() below).
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
  ~SymbolicGradientHelper() = default;

  // Returns a new function body that computes gradients of fbody_'s outputs
  // with respect to its inputs; incoming output-gradients become extra args.
  std::unique_ptr<FunctionBody> Compute();

 private:
  const FunctionBody* fbody_;  // Not owned.

  // Makes a copy of fbody_ in gbody.
  void Copy(FunctionBody* gbody);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1436 | |
1437 | void SymbolicGradientHelper::Copy(FunctionBody* gbody) { |
1438 | const Graph& src = *(fbody_->graph); |
1439 | gbody->graph = new Graph(src.op_registry()); |
1440 | Graph* dst = gbody->graph; |
1441 | |
1442 | std::vector<Node*> node_map(src.num_node_ids()); |
1443 | |
1444 | // Copy just the fdef attributes (copy '_noinline' and other similar flags to |
1445 | // the gradient function body). |
1446 | *(gbody->fdef.mutable_attr()) = fbody_->fdef.attr(); |
1447 | |
1448 | // Copy the nodes. |
1449 | node_map[src.source_node()->id()] = dst->source_node(); |
1450 | node_map[src.sink_node()->id()] = dst->sink_node(); |
1451 | for (Node* n : src.op_nodes()) { |
1452 | node_map[n->id()] = dst->CopyNode(n); |
1453 | } |
1454 | |
1455 | // Copy the edges. |
1456 | for (const Edge* e : src.edges()) { |
1457 | Node* src_copy = node_map[e->src()->id()]; |
1458 | Node* dst_copy = node_map[e->dst()->id()]; |
1459 | dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); |
1460 | } |
1461 | |
1462 | // Save inputs in copied graph. |
1463 | CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size()); |
1464 | gbody->arg_types = fbody_->arg_types; |
1465 | for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) { |
1466 | gbody->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]); |
1467 | } |
1468 | |
1469 | // Save outputs in copied graph. |
1470 | CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size()); |
1471 | gbody->ret_types = fbody_->ret_types; |
1472 | for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) { |
1473 | gbody->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]); |
1474 | } |
1475 | } |
1476 | |
// Builds the gradient function body:
//   1. Copies fbody_ into gbody.
//   2. Appends one _Arg node per original return value, carrying the
//      incoming gradient dy_i.
//   3. Calls AddSymbolicGradients() to grow the graph with gradient nodes.
//   4. Replaces the return nodes with ones returning the gradients with
//      respect to the original arguments.
std::unique_ptr<FunctionBody> SymbolicGradientHelper::Compute() {
  FunctionBody* gbody = new FunctionBody;
  Copy(gbody);  // copy fbody_ into gbody.

  Graph* g = gbody->graph;

  const int num_y = static_cast<int>(gbody->ret_nodes.size());

  // Populate 'y_node_outputs_' with node function body outputs.
  // Populate 'y_grad_nodes' with initial gradient nodes for each return node
  // of the original function body (these will be 'arg' nodes in the function
  // gradient body).
  std::vector<NodeOut> y_node_outputs;
  y_node_outputs.reserve(num_y);
  std::vector<NodeOut> y_grad_node_outputs;
  y_grad_node_outputs.reserve(num_y);
  for (int i = 0; i < num_y; ++i) {
    Node* y = gbody->ret_nodes[i];
    y_node_outputs.push_back({y, 0});
    DCHECK_EQ(y->type_string(), kRetOp);
    const DataType dtype = y->input_type(0);
    const int index = static_cast<int>(gbody->arg_nodes.size());
    // dy_i: the incoming gradient for y_i, appended as a new function arg.
    Node* dy = AddArg(g, dtype, index);
    gbody->arg_types.push_back(dtype);
    gbody->arg_nodes.push_back(dy);
    y_grad_node_outputs.push_back({dy, 0});
  }

  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
  const size_t num_x = fbody_->arg_nodes.size();
  std::vector<NodeOut> x_node_outputs;
  x_node_outputs.reserve(num_x);
  for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    x_node_outputs.push_back({gbody->arg_nodes[i], 0});
  }

  // Call AddSymbolicGradients which will add nodes to graph 'g' that
  // compute the function gradient (adding an entry in 'x_grad_node_outputs'
  // for each node in 'x_node_outputs').
  std::vector<NodeOut> x_grad_node_outputs;
  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
                                   y_grad_node_outputs, &x_grad_node_outputs,
                                   g));

  // Remove the old return nodes from the function body.
  for (Node* n : gbody->ret_nodes) {
    g->RemoveNode(n);
  }
  // The gradient body returns one value per original argument.
  gbody->ret_types = fbody_->arg_types;
  // TODO(apassos): use the right dtype for gradients of resource variables
  for (int i = 0; i < gbody->ret_types.size(); ++i) {
    if (gbody->ret_types[i] == DT_RESOURCE) {
      gbody->ret_types[i] = DT_FLOAT;
    }
  }
  gbody->ret_nodes.clear();
  // Add new return nodes to the function gradient body for each node
  // in 'x_grad_nodes'.
  const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
  for (int i = 0; i < arg_types_size; ++i) {
    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
    Node* ret = AddRet(g, grad, i);
    gbody->ret_nodes.push_back(ret);
  }

  return std::unique_ptr<FunctionBody>(gbody);
}
1544 | |
1545 | std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f) { |
1546 | return SymbolicGradientHelper(f).Compute(); |
1547 | } |
1548 | |
1549 | } // end namespace tensorflow |
1550 | |