1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ |
18 | |
19 | // GrpcServer manages the lifecycle of an Eager, Worker and Master service. |
20 | |
21 | #include <memory> |
22 | |
23 | #include "grpcpp/grpcpp.h" |
24 | #include "grpcpp/security/credentials.h" |
25 | #include "tensorflow/core/common_runtime/eager/context.h" |
26 | #include "tensorflow/core/common_runtime/process_util.h" |
27 | #include "tensorflow/core/common_runtime/stats_publisher_interface.h" |
28 | #include "tensorflow/core/distributed_runtime/master_env.h" |
29 | #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" |
30 | #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" |
31 | #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" |
32 | #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" |
33 | #include "tensorflow/core/distributed_runtime/server_lib.h" |
34 | #include "tensorflow/core/distributed_runtime/session_mgr.h" |
35 | #include "tensorflow/core/distributed_runtime/worker_env.h" |
36 | #include "tensorflow/core/framework/collective.h" |
37 | #include "tensorflow/core/framework/op.h" |
38 | #include "tensorflow/core/platform/env.h" |
39 | #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" |
40 | |
41 | namespace tensorflow { |
42 | |
43 | class GrpcWorker; |
44 | class Master; |
45 | |
46 | // function that creates a RendezvousMgr. |
47 | typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> |
48 | RendezvousMgrCreationFunction; |
49 | |
50 | // function that creates a CollectiveExecutorMgr. |
51 | typedef std::function<CollectiveExecutorMgrInterface*( |
52 | const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)> |
53 | CollectiveMgrCreationFunction; |
54 | |
55 | // function that registers a service to the server. The service needs to |
56 | // be registered before builder.BuildAndStart(). |
57 | typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> |
58 | ServiceInitFunction; |
59 | |
60 | // function that creates a grpc based worker implementation. |
61 | typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*, |
62 | const ConfigProto& config)> |
63 | WorkerCreationFunction; |
64 | |
65 | struct GrpcServerOptions { |
66 | ServiceInitFunction service_func = nullptr; |
67 | RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; |
68 | CollectiveMgrCreationFunction collective_mgr_func = nullptr; |
69 | WorkerCreationFunction worker_func = nullptr; |
70 | StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; |
71 | GrpcWorkerServiceOptions worker_service_options; |
72 | DeviceMgr* local_device_mgr = nullptr; |
73 | }; |
74 | |
75 | class GrpcServer : public ServerInterface { |
76 | protected: |
77 | GrpcServer(const ServerDef& server_def, Env* env); |
78 | GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, |
79 | Env* env); |
80 | // Allow children classes to override this and provide custom args to the |
81 | // server before it is constructed. Default behavior is to do nothing. |
82 | // requested_port provides the port requested by caller as bound_port() is |
83 | // not available till BuildAndStart has been called. |
84 | virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder, |
85 | int requested_port) {} |
86 | |
87 | public: |
88 | static Status Create(const ServerDef& server_def, Env* env, |
89 | std::unique_ptr<ServerInterface>* out_server); |
90 | static Status Create(const ServerDef& server_def, Env* env, |
91 | std::unique_ptr<GrpcServer>* out_server); |
92 | // Reuse the local_device_mgr. |
93 | static Status Create(const ServerDef& server_def, Env* env, |
94 | DeviceMgr* local_device_mgr, |
95 | std::unique_ptr<ServerInterface>* out_server); |
96 | |
97 | // Destruction is only supported in the factory method. Clean |
98 | // shutdown is not currently implemented for this server type. |
99 | virtual ~GrpcServer(); |
100 | |
101 | // Implementations of ServerInterface methods. |
102 | Status Start() override; |
103 | Status Stop() override; |
104 | Status Join() override; |
105 | const string target() const override; |
106 | |
107 | WorkerEnv* worker_env() override { return &worker_env_; } |
108 | MasterEnv* master_env() override { return &master_env_; } |
109 | |
110 | // Add master eager context to local eager service in order to handle enqueue |
111 | // requests from remote workers. |
112 | Status AddMasterEagerContextToEagerService( |
113 | const tensorflow::uint64 context_id, |
114 | tensorflow::EagerContext* context) override; |
115 | // Update the set of workers that can be reached by the GRPC server |
116 | Status UpdateServerDef(const ServerDef& server_def) override; |
117 | // Pass coordination service agent instance to server's RPC handler |
118 | Status SetCoordinationServiceAgentInstance( |
119 | CoordinationServiceAgent* agent) override; |
120 | // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is |
121 | // supported. |
122 | Status StopCoordinationService() override; |
123 | |
124 | protected: |
125 | virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name, |
126 | int* port) const; |
127 | Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); |
128 | |
129 | // A subclass can override this method to support secure credentials. |
130 | virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( |
131 | const ServerDef& server_def) const; |
132 | |
133 | virtual ChannelCreationFunction GetChannelCreationFunction() const; |
134 | |
135 | virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); |
136 | |
137 | // Creates a WorkerCacheInterface for a session. |
138 | virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, |
139 | WorkerCacheInterface** worker_cache); |
140 | |
141 | // Override to return extra services to be brought up and managed along with |
142 | // the standard {master, worker, eager} services. The map key is an aribtrary |
143 | // string and the value is a pointer to the service to be brought up. |
144 | // Ownership of the pointer is transferred to GrpcServer after this call |
145 | // returns, and the service will be destroyed during the destruction of |
146 | // GrpcServer. Each service will have its HandleRPCsLoop called in a separate |
147 | // thread. An example usage would be to add a RDMA based partial worker |
148 | // service to offload tensor and data buffer transfers. |
149 | virtual std::map<std::string, AsyncServiceInterface*> ( |
150 | ::grpc::ServerBuilder*) { |
151 | return {}; |
152 | } |
153 | |
154 | virtual std::map<std::string, AsyncServiceInterface*> () { |
155 | return extra_services_; |
156 | } |
157 | |
158 | // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. |
159 | Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, |
160 | GrpcChannelSpec* channel_spec); |
161 | |
162 | // Returns the port to which this server is bound. |
163 | // This method may only be called after `this->Init()` returns successfully. |
164 | int bound_port() const { return bound_port_; } |
165 | |
166 | // Returns hostname. |
167 | const string& host_name() const { return host_name_; } |
168 | |
169 | const ServerDef& server_def() const { return server_def_; } |
170 | GrpcWorker* worker_impl() const { return worker_impl_.get(); } |
171 | GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } |
172 | |
173 | private: |
174 | Env* env_; |
175 | |
176 | // The port to which this server is bound. |
177 | int bound_port_ = 0; |
178 | |
179 | // The host name of this server |
180 | string host_name_; |
181 | |
182 | // Guards server configuration, server, and state. |
183 | mutex mu_; |
184 | |
185 | // Represents the current state of the server, which changes as follows: |
186 | // |
187 | // Join() Join() |
188 | // ___ ___ |
189 | // Start() \ / Stop() \ / |
190 | // NEW ---------> STARTED --------> STOPPED |
191 | // \ / |
192 | // \________________________/ |
193 | // Stop(), Join() |
194 | enum State { NEW, STARTED, STOPPED }; |
195 | State state_ TF_GUARDED_BY(mu_); |
196 | |
197 | // Implementation of a TensorFlow master, and RPC polling thread. |
198 | MasterEnv master_env_; |
199 | std::unique_ptr<Master> master_impl_; |
200 | AsyncServiceInterface* master_service_ = nullptr; |
201 | std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_); |
202 | |
203 | std::map<std::string, AsyncServiceInterface*> ; |
204 | std::vector<std::unique_ptr<Thread>> |
205 | TF_GUARDED_BY(mu_); |
206 | |
207 | // Implementation of a TensorFlow worker, and RPC polling thread. |
208 | WorkerEnv worker_env_; |
209 | std::unique_ptr<const DeviceMgr> owned_device_manager_; |
210 | std::unique_ptr<GrpcWorker> worker_impl_; |
211 | AsyncServiceInterface* worker_service_ = nullptr; |
212 | std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_); |
213 | std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_; |
214 | |
215 | // TensorFlow Eager implementation, and RPC polling thread. |
216 | AsyncServiceInterface* eager_service_ = nullptr; |
217 | std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_); |
218 | std::shared_ptr<WorkerSession> worker_session_; |
219 | |
220 | // Experimental coordination service implementation, and RPC polling thread. |
221 | AsyncServiceInterface* coordination_service_ = nullptr; |
222 | std::unique_ptr<Thread> coordination_thread_ TF_GUARDED_BY(mu_); |
223 | |
224 | // TensorFlow profiler service implementation. |
225 | std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr; |
226 | |
227 | // The overall server configuration. |
228 | ServerDef server_def_ TF_GUARDED_BY(mu_); |
229 | |
230 | std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_); |
231 | }; |
232 | |
233 | } // namespace tensorflow |
234 | |
235 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ |
236 | |