1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18
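// GrpcServer manages the lifecycle of the Master, Worker, and Eager services,
// as well as the coordination service, the profiler service, and any extra
// services a subclass provides.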
20
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "grpcpp/grpcpp.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
40
41namespace tensorflow {
42
43class GrpcWorker;
44class Master;
45
46// function that creates a RendezvousMgr.
47typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
48 RendezvousMgrCreationFunction;
49
50// function that creates a CollectiveExecutorMgr.
51typedef std::function<CollectiveExecutorMgrInterface*(
52 const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
53 CollectiveMgrCreationFunction;
54
55// function that registers a service to the server. The service needs to
56// be registered before builder.BuildAndStart().
57typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
58 ServiceInitFunction;
59
60// function that creates a grpc based worker implementation.
61typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
62 const ConfigProto& config)>
63 WorkerCreationFunction;
64
65struct GrpcServerOptions {
66 ServiceInitFunction service_func = nullptr;
67 RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
68 CollectiveMgrCreationFunction collective_mgr_func = nullptr;
69 WorkerCreationFunction worker_func = nullptr;
70 StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
71 GrpcWorkerServiceOptions worker_service_options;
72 DeviceMgr* local_device_mgr = nullptr;
73};
74
75class GrpcServer : public ServerInterface {
76 protected:
77 GrpcServer(const ServerDef& server_def, Env* env);
78 GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
79 Env* env);
80 // Allow children classes to override this and provide custom args to the
81 // server before it is constructed. Default behavior is to do nothing.
82 // requested_port provides the port requested by caller as bound_port() is
83 // not available till BuildAndStart has been called.
84 virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder,
85 int requested_port) {}
86
87 public:
88 static Status Create(const ServerDef& server_def, Env* env,
89 std::unique_ptr<ServerInterface>* out_server);
90 static Status Create(const ServerDef& server_def, Env* env,
91 std::unique_ptr<GrpcServer>* out_server);
92 // Reuse the local_device_mgr.
93 static Status Create(const ServerDef& server_def, Env* env,
94 DeviceMgr* local_device_mgr,
95 std::unique_ptr<ServerInterface>* out_server);
96
97 // Destruction is only supported in the factory method. Clean
98 // shutdown is not currently implemented for this server type.
99 virtual ~GrpcServer();
100
101 // Implementations of ServerInterface methods.
102 Status Start() override;
103 Status Stop() override;
104 Status Join() override;
105 const string target() const override;
106
107 WorkerEnv* worker_env() override { return &worker_env_; }
108 MasterEnv* master_env() override { return &master_env_; }
109
110 // Add master eager context to local eager service in order to handle enqueue
111 // requests from remote workers.
112 Status AddMasterEagerContextToEagerService(
113 const tensorflow::uint64 context_id,
114 tensorflow::EagerContext* context) override;
115 // Update the set of workers that can be reached by the GRPC server
116 Status UpdateServerDef(const ServerDef& server_def) override;
117 // Pass coordination service agent instance to server's RPC handler
118 Status SetCoordinationServiceAgentInstance(
119 CoordinationServiceAgent* agent) override;
120 // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is
121 // supported.
122 Status StopCoordinationService() override;
123
124 protected:
125 virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name,
126 int* port) const;
127 Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
128
129 // A subclass can override this method to support secure credentials.
130 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
131 const ServerDef& server_def) const;
132
133 virtual ChannelCreationFunction GetChannelCreationFunction() const;
134
135 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
136
137 // Creates a WorkerCacheInterface for a session.
138 virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
139 WorkerCacheInterface** worker_cache);
140
141 // Override to return extra services to be brought up and managed along with
142 // the standard {master, worker, eager} services. The map key is an aribtrary
143 // string and the value is a pointer to the service to be brought up.
144 // Ownership of the pointer is transferred to GrpcServer after this call
145 // returns, and the service will be destroyed during the destruction of
146 // GrpcServer. Each service will have its HandleRPCsLoop called in a separate
147 // thread. An example usage would be to add a RDMA based partial worker
148 // service to offload tensor and data buffer transfers.
149 virtual std::map<std::string, AsyncServiceInterface*> ExtraServices(
150 ::grpc::ServerBuilder*) {
151 return {};
152 }
153
154 virtual std::map<std::string, AsyncServiceInterface*> GetExtraServices() {
155 return extra_services_;
156 }
157
158 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
159 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
160 GrpcChannelSpec* channel_spec);
161
162 // Returns the port to which this server is bound.
163 // This method may only be called after `this->Init()` returns successfully.
164 int bound_port() const { return bound_port_; }
165
166 // Returns hostname.
167 const string& host_name() const { return host_name_; }
168
169 const ServerDef& server_def() const { return server_def_; }
170 GrpcWorker* worker_impl() const { return worker_impl_.get(); }
171 GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
172
173 private:
174 Env* env_;
175
176 // The port to which this server is bound.
177 int bound_port_ = 0;
178
179 // The host name of this server
180 string host_name_;
181
182 // Guards server configuration, server, and state.
183 mutex mu_;
184
185 // Represents the current state of the server, which changes as follows:
186 //
187 // Join() Join()
188 // ___ ___
189 // Start() \ / Stop() \ /
190 // NEW ---------> STARTED --------> STOPPED
191 // \ /
192 // \________________________/
193 // Stop(), Join()
194 enum State { NEW, STARTED, STOPPED };
195 State state_ TF_GUARDED_BY(mu_);
196
197 // Implementation of a TensorFlow master, and RPC polling thread.
198 MasterEnv master_env_;
199 std::unique_ptr<Master> master_impl_;
200 AsyncServiceInterface* master_service_ = nullptr;
201 std::unique_ptr<Thread> master_thread_ TF_GUARDED_BY(mu_);
202
203 std::map<std::string, AsyncServiceInterface*> extra_services_;
204 std::vector<std::unique_ptr<Thread>> extra_service_threads_
205 TF_GUARDED_BY(mu_);
206
207 // Implementation of a TensorFlow worker, and RPC polling thread.
208 WorkerEnv worker_env_;
209 std::unique_ptr<const DeviceMgr> owned_device_manager_;
210 std::unique_ptr<GrpcWorker> worker_impl_;
211 AsyncServiceInterface* worker_service_ = nullptr;
212 std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
213 std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
214
215 // TensorFlow Eager implementation, and RPC polling thread.
216 AsyncServiceInterface* eager_service_ = nullptr;
217 std::unique_ptr<Thread> eager_thread_ TF_GUARDED_BY(mu_);
218 std::shared_ptr<WorkerSession> worker_session_;
219
220 // Experimental coordination service implementation, and RPC polling thread.
221 AsyncServiceInterface* coordination_service_ = nullptr;
222 std::unique_ptr<Thread> coordination_thread_ TF_GUARDED_BY(mu_);
223
224 // TensorFlow profiler service implementation.
225 std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
226
227 // The overall server configuration.
228 ServerDef server_def_ TF_GUARDED_BY(mu_);
229
230 std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_);
231};
232
233} // namespace tensorflow
234
235#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
236