1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ |
18 | |
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
25 | |
26 | namespace tensorflow { |
27 | |
28 | class ClusterFunctionLibraryRuntime; |
29 | class GraphMgr; |
30 | class WorkerCacheInterface; |
31 | |
32 | // WorkerSession encapsulates all of the state relating to a given session. |
33 | class WorkerSession { |
34 | public: |
35 | // Collection of local devices. These devices are typically |
36 | // RenamedDevices in all except the SessionMgr.legacy_session_ and |
37 | // sessions created with `isolate_session_state == false`. In the |
38 | // those cases, this method returns a pointer to a borrowed |
39 | // DeviceMgr (typically the `worker_env.device_mgr`). |
40 | DeviceMgr* device_mgr() { |
41 | return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_; |
42 | } |
43 | |
44 | DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); } |
45 | |
46 | const string& session_name() const { return session_name_; } |
47 | const string& worker_name() const { return worker_name_; } |
48 | |
49 | WorkerCacheInterface* worker_cache() const { |
50 | tf_shared_lock l(worker_session_state_mu_); |
51 | return worker_cache_.get(); |
52 | } |
53 | GraphMgr* graph_mgr() const { return graph_mgr_.get(); } |
54 | |
55 | ClusterFunctionLibraryRuntime* cluster_flr() const { |
56 | return cluster_flr_.get(); |
57 | } |
58 | |
59 | WorkerSession(const string& session_name, const string& worker_name, |
60 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
61 | std::unique_ptr<DeviceMgr> device_mgr, |
62 | std::unique_ptr<GraphMgr> graph_mgr, |
63 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); |
64 | |
65 | static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr( |
66 | const string& session_name, const string& worker_name, |
67 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
68 | DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, |
69 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); |
70 | |
71 | // In the eager runtime we allow WorkerSession to be updated, where the |
72 | // worker cache will be recreated. If WorkerSession upate is expected and a |
73 | // worker in the cache is used in RPCs, the caller should hold a shared |
74 | // pointer to avoid the workers getting deleted. |
75 | std::shared_ptr<WorkerCacheInterface> GetSharedWorkerCache() { |
76 | tf_shared_lock l(worker_session_state_mu_); |
77 | return worker_cache_; |
78 | } |
79 | |
80 | // Update an existing worker session with new set of remote workers and |
81 | // devices. Added devices will be owned by the worker session, and removed |
82 | // devices will be freed by their names. |
83 | Status UpdateWorkerCacheAndDevices( |
84 | std::unique_ptr<WorkerCacheInterface> new_worker_cache, |
85 | std::vector<std::unique_ptr<Device>> added_remote_devices, |
86 | const std::vector<Device*>& removed_remote_devices); |
87 | |
88 | ~WorkerSession(); |
89 | |
90 | private: |
91 | WorkerSession(const string& session_name, const string& worker_name, |
92 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
93 | DeviceMgr* borrowed_device_mgr, |
94 | std::unique_ptr<GraphMgr> graph_mgr, |
95 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); |
96 | |
97 | // The name of the session. |
98 | const string session_name_; |
99 | |
100 | // The name of the worker. E.g., /job:mnist/replica:0/task:1. |
101 | const string worker_name_; |
102 | |
103 | mutable mutex worker_session_state_mu_; |
104 | // Object from which WorkerInterface instances can be obtained. |
105 | std::shared_ptr<WorkerCacheInterface> worker_cache_ |
106 | TF_GUARDED_BY(worker_session_state_mu_); |
107 | |
108 | // graph_mgr keeps track of the registered graphs of this session. |
109 | // |
110 | // Note: graph_mgr must be deleted before rendezvous_mgr! |
111 | // Note: graph_mgr must be deleted before device_mgr! |
112 | const std::unique_ptr<GraphMgr> graph_mgr_; |
113 | |
114 | std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_; |
115 | |
116 | const std::unique_ptr<DeviceMgr> device_mgr_; |
117 | DeviceMgr* const borrowed_device_mgr_; // Not owned. |
118 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_; |
119 | }; |
120 | |
121 | } // namespace tensorflow |
122 | |
123 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ |
124 | |