1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ |
16 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ |
17 | |
18 | #include "absl/types/optional.h" |
19 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
20 | #include "tensorflow/core/distributed_runtime/worker_interface.h" |
21 | #include "tensorflow/core/framework/function.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | class WorkerSession; |
26 | |
27 | // ClusterFunctionLibraryRuntime contains methods to Instantiate and Run |
28 | // functions across processes by making RPCs through worker service. |
29 | class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { |
30 | public: |
31 | ClusterFunctionLibraryRuntime(WorkerSession* worker_session, |
32 | bool create_worker_session_called, |
33 | DeviceMgr* remote_device_mgr) |
34 | : worker_session_(worker_session), |
35 | create_worker_session_called_(create_worker_session_called), |
36 | remote_device_mgr_(remote_device_mgr) {} |
37 | |
38 | ~ClusterFunctionLibraryRuntime() override; |
39 | |
40 | void Instantiate(const string& function_name, |
41 | const FunctionLibraryDefinition& lib_def, AttrSlice attrs, |
42 | const FunctionLibraryRuntime::InstantiateOptions& options, |
43 | FunctionLibraryRuntime::LocalHandle* handle, |
44 | FunctionLibraryRuntime::DoneCallback done) override; |
45 | |
46 | void Run(const FunctionLibraryRuntime::Options& opts, |
47 | FunctionLibraryRuntime::LocalHandle handle, |
48 | gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, |
49 | FunctionLibraryRuntime::DoneCallback done) override; |
50 | |
51 | void Run(const FunctionLibraryRuntime::Options& opts, |
52 | FunctionLibraryRuntime::LocalHandle handle, |
53 | gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets, |
54 | FunctionLibraryRuntime::DoneCallback done) override; |
55 | |
56 | void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, |
57 | FunctionLibraryRuntime::DoneCallback done) override; |
58 | |
59 | DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } |
60 | |
61 | private: |
62 | static Status ConstructFunctionGraph( |
63 | const OpDef& sig, AttrSlice attrs, |
64 | const FunctionLibraryRuntime::InstantiateOptions& options, |
65 | const FunctionLibraryDefinition& flib_def, GraphDef* g, |
66 | std::vector<string>* send_keys, std::vector<string>* recv_keys); |
67 | friend class ClusterFunctionLibraryRuntimeTest; |
68 | |
69 | mutable mutex mu_; |
70 | WorkerSession* const worker_session_ = nullptr; // not owned. |
71 | const bool create_worker_session_called_; |
72 | |
73 | DeviceMgr* remote_device_mgr_; // not owned. |
74 | |
75 | struct FunctionData { |
76 | const string graph_handle; |
77 | const string target; |
78 | // Hold a shared pointer to the underlying worker cache to avoid it being |
79 | // deleted in potential cluster update. |
80 | const std::shared_ptr<WorkerCacheInterface> worker_cache; |
81 | WorkerInterface* wi = nullptr; |
82 | const std::vector<string> send_keys; |
83 | const std::vector<string> recv_keys; |
84 | |
85 | FunctionData(const string& graph_handle, const string& target, |
86 | std::shared_ptr<WorkerCacheInterface> worker_cache, |
87 | WorkerInterface* wi, const std::vector<string>& send_keys, |
88 | const std::vector<string>& recv_keys) |
89 | : graph_handle(graph_handle), |
90 | target(target), |
91 | worker_cache(std::move(worker_cache)), |
92 | wi(wi), |
93 | send_keys(send_keys), |
94 | recv_keys(recv_keys) {} |
95 | }; |
96 | |
97 | std::vector<FunctionData> function_data_ TF_GUARDED_BY(mu_); |
98 | }; |
99 | |
100 | } // namespace tensorflow |
101 | |
102 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ |
103 | |