1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ |
17 | |
18 | #include <functional> |
19 | #include <unordered_map> |
20 | |
21 | // clang-format off |
22 | // Required for IS_MOBILE_PLATFORM |
23 | #include "tensorflow/core/platform/platform.h" |
24 | // clang-format on |
25 | |
26 | #include "absl/types/optional.h" |
27 | #include "absl/types/variant.h" |
28 | #include "tensorflow/core/common_runtime/composite_device.h" |
29 | #include "tensorflow/core/common_runtime/device_mgr.h" |
30 | #include "tensorflow/core/common_runtime/device_set.h" |
31 | #include "tensorflow/core/framework/function.h" |
32 | #include "tensorflow/core/framework/types.h" |
33 | #include "tensorflow/core/lib/core/status.h" |
34 | #include "tensorflow/core/protobuf/config.pb.h" |
35 | #if !defined(IS_MOBILE_PLATFORM) |
36 | #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" |
37 | #endif // IS_MOBILE_PLATFORM |
38 | |
39 | namespace tensorflow { |
40 | |
41 | class FunctionArgsInterface { |
42 | public: |
43 | virtual ~FunctionArgsInterface() {} |
44 | |
45 | virtual bool HasRemoteOrPackedInputs() const = 0; |
46 | |
47 | virtual Status GetLocalArg(const FunctionArgIndex& index, |
48 | Tensor* val) const = 0; |
49 | |
50 | virtual std::vector<Tensor> GetLocalTensors() const = 0; |
51 | |
52 | #if !defined(IS_MOBILE_PLATFORM) |
53 | virtual Status GetRemoteArg(const FunctionArgIndex& index, |
54 | eager::RemoteTensorHandle* val) const { |
55 | return errors::Unimplemented( |
56 | "Serializing a remote argument is not implemented." ); |
57 | } |
58 | #endif // IS_MOBILE_PLATFORM |
59 | }; |
60 | |
61 | // A class that stores all the FunctionLibraryRuntime objects, one per device. |
62 | class ProcessFunctionLibraryRuntime { |
63 | public: |
64 | // Creates FunctionLibraryRuntime objects for each device in the provided |
65 | // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent |
66 | // (if provided) outlive this object. |
67 | ProcessFunctionLibraryRuntime( |
68 | const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, |
69 | int graph_def_version, const FunctionLibraryDefinition* lib_def, |
70 | const OptimizerOptions& optimizer_options, |
71 | thread::ThreadPool* thread_pool = nullptr, |
72 | DistributedFunctionLibraryRuntime* parent = nullptr, |
73 | const SessionMetadata* session_metadata = nullptr, |
74 | Rendezvous::Factory rendezvous_factory = Rendezvous::Factory()); |
75 | |
76 | ~ProcessFunctionLibraryRuntime() { |
77 | // Deleting the FunctionLibraryRuntime map will delete the function handles |
78 | // registered in it, which may call ReleaseHandle in this class again to |
79 | // release their sub-function. These circular calls may cause segfault |
80 | // since the flr_map_ may have already been deleted. Explicitly releasing |
81 | // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this. |
82 | flr_map_.reset(); |
83 | } |
84 | |
85 | // Sends `tensors_to_send` from `source_device` to `target_device` using |
86 | // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the |
87 | // Rendezvous. `device_context` should be the DeviceContext of the device |
88 | // doing the sending. `alloc_attrs` should either be empty or be the size of |
89 | // `tensors_to_send` and indicates how the input tensors are allocated. Method |
90 | // takes references on each of the `tensors_to_send`. Method doesn't block. |
91 | static Status SendTensors(const string& source_device, |
92 | const string& target_device, |
93 | const string& key_prefix, int64_t src_incarnation, |
94 | gtl::ArraySlice<Tensor> tensors_to_send, |
95 | DeviceContext* device_context, |
96 | const std::vector<AllocatorAttributes>& alloc_attrs, |
97 | RendezvousInterface* rendezvous); |
98 | |
99 | // Receives `received_tensors` from `target_device` (originally sent from |
100 | // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the |
101 | // keys to be retrieved. `device_context` should be for the device receiving |
102 | // the tensors. `alloc_attrs` indicates how to allocate the received |
103 | // tensors and should either be empty or `num_tensors` in size. Method doesn't |
104 | // block and calls `done` when `num_tensors` are fetched. |
105 | static void ReceiveTensorsAsync( |
106 | const string& source_device, const string& target_device, |
107 | const string& key_prefix, int64_t src_incarnation, int64_t num_tensors, |
108 | DeviceContext* device_context, |
109 | const std::vector<AllocatorAttributes>& alloc_attrs, |
110 | RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors, |
111 | StatusCallback done); |
112 | |
113 | static const char kDefaultFLRDevice[]; |
114 | // Returns the FunctionLibraryRuntime for the corresponding device_name. |
115 | FunctionLibraryRuntime* GetFLR(const string& device_name) const; |
116 | |
117 | // Returns the return types for the function identified by handle `h`. |
118 | Status GetRetTypes(FunctionLibraryRuntime::Handle h, |
119 | DataTypeVector* ret_types); |
120 | |
121 | // Returns the device incarnation for the given device_name. |
122 | Status GetDeviceIncarnation(const string& device_name, |
123 | int64_t* incarnation) const; |
124 | |
125 | // For a given canonicalized key signature of the function instantiated |
126 | // on device `device_name` and a `local_handle`, creates a handle and returns |
127 | // that value. Uses core/common_runtime/framework/function.h::Canonicalize |
128 | // to canonicalize the function signature. |
129 | FunctionLibraryRuntime::Handle AddHandle( |
130 | const string& function_key, const string& device_name, |
131 | FunctionLibraryRuntime::LocalHandle local_handle); |
132 | |
133 | // Returns a handle if found for the given key, else returns kInvalidHandle. |
134 | FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; |
135 | |
136 | // For the given handle instantiated on device `device_name` returns the local |
137 | // index of instantiation of that function. If the function was not |
138 | // instantiated on `device_name` or the function is multi-device, |
139 | // returns kInvalidLocalHandle. |
140 | // |
141 | // If `include_multi_device` is true and `handle` is a multi-device function |
142 | // with a single component that is placed on `device_name`, then this method |
143 | // will return the local handle for that component. |
144 | FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( |
145 | const string& device_name, FunctionLibraryRuntime::Handle handle, |
146 | bool include_multi_device = false) const; |
147 | |
148 | // Fills `output_devices` with the devices on which the results will |
149 | // be produced. If some output is produced on CPU, the corresponding Device* |
150 | // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device* |
151 | // is set to the device backing the resource. |
152 | // REQUIRES: `handle` identifies a multi-device function. |
153 | Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, |
154 | std::vector<Device*>* output_devices) const; |
155 | |
156 | // Instantiates the function. See framework/function.h for more details. |
157 | // Allows for function_name to be instantiated on different devices |
158 | // as specified in attrs. |
159 | Status Instantiate(const string& function_name, AttrSlice attrs, |
160 | const FunctionLibraryRuntime::InstantiateOptions& options, |
161 | FunctionLibraryRuntime::Handle* handle); |
162 | |
163 | // Returns whether the function represented by the given handle needs to |
164 | // execute cross process. |
165 | Status IsCrossProcess(FunctionLibraryRuntime::Handle handle, |
166 | bool* is_cross_process) const; |
167 | |
168 | // TODO(iga): Reword |
169 | // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the |
170 | // corresponding resource lives. This ensures that the Placer assigns ops that |
171 | // access these resources to the appropriate devices. |
172 | static Status PinArgsAndRets(const std::vector<string>& input_devices, |
173 | const std::vector<string>& output_devices, |
174 | const DeviceSet& device_set, |
175 | const std::vector<Node*>& arg_nodes, |
176 | const std::vector<Node*>& ret_nodes, |
177 | const FunctionLibraryDefinition* lib_def, |
178 | Device* default_device); |
179 | |
180 | // Delegates to the local FLR that owns state corresponding to `handle` and |
181 | // tells it to release it. If the `handle` isn't needed at all, the local FLR |
182 | // might call RemoveHandle on this to get rid of the state owned by the Proc |
183 | // FLR. |
184 | // For multi-device functions, calls ReleaseHandle on local FLRs for each |
185 | // component function that is part of this multi-device function. |
186 | // Each local FLR might call RemoveHandle on this. |
187 | Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); |
188 | |
189 | // Runs the function with given `handle`. Function could have been |
190 | // instantiated on any device. More details in framework/function.h |
191 | void Run(const FunctionLibraryRuntime::Options& opts, |
192 | FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, |
193 | std::vector<Tensor>* rets, |
194 | FunctionLibraryRuntime::DoneCallback done) const; |
195 | void Run(const FunctionLibraryRuntime::Options& opts, |
196 | FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame, |
197 | FunctionLibraryRuntime::DoneCallback done) const; |
198 | |
199 | void Run(const FunctionLibraryRuntime::Options& opts, |
200 | FunctionLibraryRuntime::Handle handle, |
201 | const FunctionArgsInterface& args, std::vector<FunctionRet>* rets, |
202 | FunctionLibraryRuntime::DoneCallback done) const; |
203 | |
204 | Status RunSync(const FunctionLibraryRuntime::Options& opts, |
205 | FunctionLibraryRuntime::Handle handle, |
206 | gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const; |
207 | Status RunSync(const FunctionLibraryRuntime::Options& opts, |
208 | FunctionLibraryRuntime::Handle handle, |
209 | CallFrameInterface* frame) const; |
210 | |
211 | const DeviceMgr* device_mgr() { return device_mgr_; } |
212 | |
213 | const std::shared_ptr<DeviceSet> device_set() const { |
214 | tf_shared_lock l(mu_); |
215 | return device_set_; |
216 | } |
217 | |
218 | // Initialize the set of local and remote devices and corresponding flr for op |
219 | // device selection. |
220 | void InitializeDeviceAndFlr(); |
221 | |
222 | const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } |
223 | |
224 | const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const { |
225 | return lib_def_; |
226 | } |
227 | |
228 | // Add a CompositeDevice to `device_set_` |
229 | void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) { |
230 | mutex_lock l(mu_); |
231 | device_set_->AddDevice(d); |
232 | composite_devices_.push_back(d); |
233 | } |
234 | |
235 | protected: |
236 | friend class FunctionLibraryRuntimeImpl; |
237 | |
238 | struct InternalArgs { |
239 | std::vector<FunctionArg> args; |
240 | #if !defined(IS_MOBILE_PLATFORM) |
241 | // Holds the RemoteTensorHandles referred by args. |
242 | std::vector<std::unique_ptr<eager::RemoteTensorHandle>> remote_args; |
243 | #endif // IS_MOBILE_PLATFORM |
244 | }; |
245 | |
246 | // Structure detailing the asynchronous assumptions of a component function, |
247 | // such as whether it can support synchronous execution and any information |
248 | // needed to execute in proper order to resolve inter-subgraph dependencies. |
249 | class AsyncAttributes { |
250 | public: |
251 | enum Summary { kSafeForSync = 0, kSendOnly, kRecvOnly, kAsyncRequired }; |
252 | |
253 | AsyncAttributes() |
254 | : allow_control_flow_sync_execution_(false), summary_(kSafeForSync) {} |
255 | explicit AsyncAttributes(const Graph* graph, |
256 | bool allow_control_flow_sync_execution) |
257 | : allow_control_flow_sync_execution_(allow_control_flow_sync_execution), |
258 | summary_(Summarize(graph)) {} |
259 | Summary summary() const { return summary_; } |
260 | bool allow_control_flow_sync_execution() const { |
261 | return allow_control_flow_sync_execution_; |
262 | } |
263 | |
264 | private: |
265 | // This data member should be initialized before the summary_. |
266 | bool allow_control_flow_sync_execution_; |
267 | Summary summary_; |
268 | Summary Summarize(const Graph* graph); |
269 | }; |
270 | |
271 | // Structure to keep track of how a component function (a single-device |
272 | // piece of a multi-device function) fits into the multi-device function. |
273 | struct ComponentFunctionData { |
274 | // The handle for the instantiated component function. |
275 | FunctionLibraryRuntime::Handle handle; |
276 | // arg_indices.size() is the number of arguments to the component function. |
277 | // The i-th argument of the component function comes from the |
278 | // `arg_indices[i]`-th argument of the multi-device function. |
279 | std::vector<FunctionArgIndex> arg_indices; |
280 | // ret_indices.size() is the number of return values of the component |
281 | // function. The i-th return value of the component function goes to the |
282 | // `ret_indices[i]`-th return value of the multi-device function. |
283 | std::vector<int> ret_indices; |
284 | // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to |
285 | // the component function. |
286 | std::vector<AllocatorAttributes> arg_alloc_attrs; |
287 | // ret_alloc_attrs[i] are the allocator attributes of the i-th return value |
288 | // of the component function. |
289 | std::vector<AllocatorAttributes> ret_alloc_attrs; |
290 | |
291 | AsyncAttributes async_attributes; |
292 | }; |
293 | |
294 | // Data structure holding information for a single instantiated multi-device |
295 | // function. |
296 | // The fields are filled in during instantiation. Once the object is |
297 | // added to mdevice_data_, all fields are constant. |
298 | struct MultiDeviceFunctionData { |
299 | MultiDeviceFunctionData(const string& function_name, |
300 | const string& function_key, int num_outputs, |
301 | FunctionLibraryDefinition&& lib_def, |
302 | DataTypeVector ret_types) |
303 | : function_name_(function_name), |
304 | function_key_(function_key), |
305 | instantiation_counter_(1), |
306 | lib_def_(std::move(lib_def)), |
307 | num_outputs_(num_outputs), |
308 | ret_types_(std::move(ret_types)), |
309 | is_cross_process_(false), |
310 | has_remote_outputs(false) {} |
311 | |
312 | const string function_name_; |
313 | const string function_key_; |
314 | uint64 instantiation_counter_; |
315 | // A library that contains definitions of component functions and their |
316 | // transitive dependencies. |
317 | FunctionLibraryDefinition lib_def_; |
318 | // Stored here to resize the output tensor vector when function is run. |
319 | const int num_outputs_; |
320 | DataTypeVector ret_types_; |
321 | |
322 | // Indicates whether this function needs to execute cross process. |
323 | bool is_cross_process_; |
324 | // Indicates whether this function has remote outputs. |
325 | bool has_remote_outputs; |
326 | |
327 | // Indicates if running this function synchronously is both allowed + safe. |
328 | bool enable_sync_execution; |
329 | |
330 | // Maps the device name to the information about the component function |
331 | // be run on this device. |
332 | std::unordered_map<string, ComponentFunctionData> glue_; |
333 | }; |
334 | |
335 | struct CleanUpItem { |
336 | string device; |
337 | uint64 step_id; |
338 | FunctionLibraryRuntime::Handle local_handle; |
339 | }; |
340 | |
341 | // If `handle` represents a multi-device function, returns the multi-device |
342 | // data associated with `handle`. Else, nullptr. |
343 | MultiDeviceFunctionData* IsMultiDevice( |
344 | FunctionLibraryRuntime::Handle handle) const; |
345 | |
346 | DistributedFunctionLibraryRuntime* const parent_; |
347 | |
348 | private: |
349 | FunctionLibraryRuntime::Handle AddHandleLocked( |
350 | const string& function_key, const string& device_name, |
351 | FunctionLibraryRuntime::LocalHandle local_handle) |
352 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
353 | |
354 | // For a given device_name, returns a DeviceContext for copying |
355 | // tensors to/from the device. |
356 | Status GetDeviceContext(const string& device_name, |
357 | DeviceContext** device_context) const; |
358 | |
359 | // Looks up the information for the given `handle` and returns the name |
360 | // of the device where the function is registered. |
361 | string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; |
362 | |
363 | // Removes handle from the state owned by this object. |
364 | Status RemoveHandle(FunctionLibraryRuntime::Handle handle); |
365 | |
366 | // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition |
367 | // (transferring ownership of both to the caller). Note that the |
368 | // ProcessFunctionLibraryRuntime borrows a pointer to the |
369 | // FunctionLibraryDefinition and so the FunctionLibraryDefinition should |
370 | // outlive the ProcessFunctionLibraryRuntime. |
371 | // |
372 | // The `skip_flib_def` argument controls whether the method should clone the |
373 | // FunctionLibraryDefinition (default behavior) or return an empty function |
374 | // library. The latter is used by tf.data, which manages |
375 | // FunctionLibraryDefinitions for its functions independently (and passes |
376 | // these into the FunctionLibraryRuntime through an overlay), to avoid linear |
377 | // runtime w.r.t. to number of functions in the current function library. |
378 | Status Clone(Env* env, int graph_def_version, |
379 | const OptimizerOptions& optimizer_options, |
380 | std::unique_ptr<FunctionLibraryDefinition>* out_lib_def, |
381 | std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr, |
382 | bool skip_flib_def = false) const; |
383 | |
384 | Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); |
385 | |
386 | // Function graph related information after optimizations. |
387 | struct OptimizedFunctionGraphInfo { |
388 | // Optimized graph. |
389 | std::unique_ptr<Graph> graph; |
390 | // Optimized function library. |
391 | FunctionLibraryDefinition lib_def; |
392 | // Map from original node names to control return names. |
393 | std::unordered_map<string, string> node_name_to_control_ret; |
394 | // Return node types of the function. |
395 | DataTypeVector ret_types; |
396 | // Number of return nodes. |
397 | size_t num_return_nodes; |
398 | }; |
399 | |
400 | // Outputs graph optimization result after all the graph optimization (up till |
401 | // before graph partitioning); returns error if optimization fails. |
402 | StatusOr<OptimizedFunctionGraphInfo> OptimizeFunctionGraph( |
403 | const string& function_name, AttrSlice attrs, |
404 | const FunctionLibraryRuntime::InstantiateOptions& options, |
405 | const std::shared_ptr<DeviceSet>& dev_set); |
406 | |
407 | Status InstantiateMultiDevice( |
408 | const string& function_name, AttrSlice attrs, |
409 | const FunctionLibraryRuntime::InstantiateOptions& options, |
410 | FunctionLibraryRuntime::Handle* handle); |
411 | |
412 | void InstantiateRemote( |
413 | const string& function_name, AttrSlice attrs, |
414 | const FunctionLibraryRuntime::InstantiateOptions& options, |
415 | FunctionLibraryRuntime::Handle* handle, |
416 | FunctionLibraryRuntime::DoneCallback done); |
417 | |
418 | FunctionLibraryRuntime::Handle AddMultiDeviceHandle( |
419 | const std::unique_ptr<MultiDeviceFunctionData> data, |
420 | const string& function_key); |
421 | |
422 | bool HasMultiDeviceHandle(FunctionLibraryRuntime::Handle handle) const; |
423 | |
424 | void RunInternal(const FunctionLibraryRuntime::Options& opts, |
425 | FunctionLibraryRuntime::Handle handle, |
426 | gtl::ArraySlice<FunctionArg> args, |
427 | std::vector<FunctionRet>* rets, |
428 | std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items, |
429 | FunctionLibraryRuntime::DoneCallback done) const; |
430 | |
431 | Status CreateRendezvous(FunctionLibraryRuntime::Options& opts, |
432 | Rendezvous** created_rendezvous) const; |
433 | |
434 | void CleanupCreatedRendezvous(const Rendezvous* created_rendezvous, |
435 | const int64_t step_id) const; |
436 | |
437 | FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback( |
438 | std::vector<std::unique_ptr<CleanUpItem>>* items, |
439 | FunctionLibraryRuntime::DoneCallback done, const int64_t step_id, |
440 | const Rendezvous* rendezvous) const; |
441 | |
442 | void CleanUp(std::vector<std::unique_ptr<CleanUpItem>>* items, |
443 | FunctionLibraryRuntime::DoneCallback done) const; |
444 | |
445 | static Status GetComponentArgs(gtl::ArraySlice<Tensor> args, |
446 | const ComponentFunctionData& comp_data, |
447 | InternalArgs* comp_args); |
448 | |
449 | #if !defined(IS_MOBILE_PLATFORM) |
450 | static Status GetComponentArgs(const FunctionArgsInterface& args, |
451 | const ComponentFunctionData& comp_data, |
452 | InternalArgs* comp_args); |
453 | #endif // IS_MOBILE_PLATFORM |
454 | |
455 | std::vector<string> GetOrderedSubgraphs( |
456 | const MultiDeviceFunctionData* data) const; |
457 | |
458 | Status PrepareRunMultiDevice(const FunctionLibraryRuntime::Options& opts, |
459 | FunctionLibraryRuntime::Handle handle, |
460 | const MultiDeviceFunctionData** data) const; |
461 | |
462 | Status RunMultiDeviceSync( |
463 | const FunctionLibraryRuntime::Options& opts, |
464 | FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets, |
465 | std::function<Status(const ComponentFunctionData& comp_data, |
466 | InternalArgs* args)> |
467 | get_component_args) const; |
468 | |
469 | void RunMultiDeviceAsync( |
470 | const FunctionLibraryRuntime::Options& opts, |
471 | FunctionLibraryRuntime::Handle handle, std::vector<FunctionRet>* rets, |
472 | std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items, |
473 | FunctionLibraryRuntime::DoneCallback done, |
474 | std::function<Status(const ComponentFunctionData& comp_data, |
475 | InternalArgs* args)> |
476 | get_component_args) const; |
477 | |
478 | // Data structure holding information for a single instantiated remote |
479 | // (to be executed on `target_device`) function. |
480 | class FunctionData { |
481 | public: |
482 | FunctionData(const string& target_device, |
483 | FunctionLibraryRuntime::LocalHandle local_handle, |
484 | const string& function_key) |
485 | : target_device_(target_device), |
486 | local_handle_(local_handle), |
487 | function_key_(function_key) {} |
488 | |
489 | const string& target_device() { return target_device_; } |
490 | const string& function_key() { return function_key_; } |
491 | |
492 | FunctionLibraryRuntime::LocalHandle local_handle() { |
493 | mutex_lock l(mu_); |
494 | return local_handle_; |
495 | } |
496 | |
497 | // Initializes the FunctionData object by potentially making an Initialize |
498 | // call to the DistributedFunctionLibraryRuntime. |
499 | void DistributedInit( |
500 | DistributedFunctionLibraryRuntime* parent, const string& function_name, |
501 | const FunctionLibraryDefinition& lib_def, AttrSlice attrs, |
502 | const FunctionLibraryRuntime::InstantiateOptions& options, |
503 | FunctionLibraryRuntime::DoneCallback done); |
504 | |
505 | bool is_cross_process() { |
506 | mutex_lock l(mu_); |
507 | return is_cross_process_; |
508 | } |
509 | |
510 | private: |
511 | mutex mu_; |
512 | |
513 | const string target_device_; |
514 | FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_); |
515 | const string function_key_; |
516 | bool is_cross_process_ TF_GUARDED_BY(mu_) = false; |
517 | bool init_started_ TF_GUARDED_BY(mu_) = false; |
518 | Status init_result_ TF_GUARDED_BY(mu_); |
519 | Notification init_done_; |
520 | }; |
521 | |
522 | mutable mutex mu_; |
523 | |
524 | Env* const env_; |
525 | const absl::optional<const ConfigProto> config_; |
526 | const DeviceMgr* const device_mgr_; |
527 | const FunctionLibraryDefinition* lib_def_; |
528 | thread::ThreadPool* default_thread_pool_; |
529 | |
530 | // Cluster update can reinitialize the device_set_ due to remote device |
531 | // changes. At the same time, InstantiateMultiDevice can use the cached |
532 | // devices to instantiate multi-worker functions. Function instantiation would |
533 | // fail if it spans the changed remote devices. |
534 | std::shared_ptr<DeviceSet> device_set_ TF_GUARDED_BY(mu_); |
535 | |
536 | // Composite devices owned by a EagerContext. |
537 | std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_); |
538 | |
539 | // Holds all the function instantiations. Maps function_keys to handles. |
540 | std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ |
541 | TF_GUARDED_BY(mu_); |
542 | |
543 | // Function data for instantiated remote functions. |
544 | std::unordered_map<FunctionLibraryRuntime::Handle, |
545 | std::unique_ptr<FunctionData>> |
546 | function_data_ TF_GUARDED_BY(mu_); |
547 | |
548 | // Function data for instantiated multi-device functions. |
549 | std::unordered_map<FunctionLibraryRuntime::Handle, |
550 | std::unique_ptr<MultiDeviceFunctionData>> |
551 | mdevice_data_ TF_GUARDED_BY(mu_); |
552 | |
553 | std::unique_ptr< |
554 | std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>>> |
555 | flr_map_; |
556 | int next_handle_ TF_GUARDED_BY(mu_); |
557 | const SessionMetadata* const session_metadata_; |
558 | const Rendezvous::Factory rendezvous_factory_; |
559 | |
560 | const OptimizerOptions optimizer_options_; |
561 | const int graph_def_version_; |
562 | }; |
563 | |
564 | } // namespace tensorflow |
565 | |
566 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ |
567 | |