1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ |
18 | |
19 | #include <memory> |
20 | |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | #include "tensorflow/core/platform/macros.h" |
23 | #include "tensorflow/core/protobuf/tensorflow_server.pb.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | class CoordinationServiceAgent; |
28 | class DeviceMgr; |
29 | class EagerContext; |
30 | class WorkerEnv; |
31 | class MasterEnv; |
32 | |
33 | // This library supports a registration/factory-based mechanism for |
34 | // creating TensorFlow server objects. Each server implementation must |
35 | // have an accompanying implementation of ServerFactory, and create a |
36 | // static "registrar" object that calls `ServerFactory::Register()` |
37 | // with an instance of the factory class. See "rpc/grpc_server_lib.cc" |
38 | // for an example. |
39 | |
40 | // Represents a single TensorFlow server that exports Master and Worker |
41 | // services. |
42 | class ServerInterface { |
43 | public: |
44 | ServerInterface() {} |
45 | virtual ~ServerInterface() {} |
46 | |
47 | // Starts the server running asynchronously. Returns OK on success, otherwise |
48 | // returns an error. |
49 | virtual Status Start() = 0; |
50 | |
51 | // Stops the server asynchronously. Returns OK on success, otherwise returns |
52 | // an error. |
53 | // |
54 | // After calling `Stop()`, the caller may call `Join()` to block until the |
55 | // server has stopped. |
56 | virtual Status Stop() = 0; |
57 | |
58 | // Blocks until the server has stopped. Returns OK on success, otherwise |
59 | // returns an error. |
60 | virtual Status Join() = 0; |
61 | |
62 | // Returns a target string that can be used to connect to this server using |
63 | // `tensorflow::NewSession()`. |
64 | virtual const string target() const = 0; |
65 | |
66 | virtual WorkerEnv* worker_env() = 0; |
67 | virtual MasterEnv* master_env() = 0; |
68 | |
69 | // Update the set of workers that can be reached by the server |
70 | virtual Status UpdateServerDef(const ServerDef& server_def) = 0; |
71 | |
72 | // Functions to operate on service-specific properties. |
73 | // |
74 | // Add master eager context to local eager service in order to handle enqueue |
75 | // requests from remote workers. |
76 | virtual Status AddMasterEagerContextToEagerService( |
77 | const tensorflow::uint64 context_id, EagerContext* context) = 0; |
78 | // Set coordination service agent instance to coordination service RPC handler |
79 | virtual Status SetCoordinationServiceAgentInstance( |
80 | CoordinationServiceAgent* agent) = 0; |
81 | // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is |
82 | // supported. |
83 | virtual Status StopCoordinationService() = 0; |
84 | |
85 | private: |
86 | TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface); |
87 | }; |
88 | |
89 | class ServerFactory { |
90 | public: |
91 | struct Options { |
92 | // Local DeviceMgr to use. |
93 | tensorflow::DeviceMgr* local_device_mgr; |
94 | }; |
95 | // Creates a new server based on the given `server_def`, and stores |
96 | // it in `*out_server`. Returns OK on success, otherwise returns an |
97 | // error. |
98 | virtual Status NewServer(const ServerDef& server_def, const Options& options, |
99 | std::unique_ptr<ServerInterface>* out_server) = 0; |
100 | |
101 | // Returns true if and only if this factory can create a server |
102 | // based on the given `server_def`. |
103 | virtual bool AcceptsOptions(const ServerDef& server_def) = 0; |
104 | |
105 | virtual ~ServerFactory() {} |
106 | |
107 | // For each `ServerFactory` subclass, an instance of that class must |
108 | // be registered by calling this method. |
109 | // |
110 | // The `server_type` must be unique to the server factory. |
111 | static void Register(const string& server_type, ServerFactory* factory); |
112 | |
113 | // Looks up a factory that can create a server based on the given |
114 | // `server_def`, and stores it in `*out_factory`. Returns OK on |
115 | // success, otherwise returns an error. |
116 | static Status GetFactory(const ServerDef& server_def, |
117 | ServerFactory** out_factory); |
118 | }; |
119 | |
120 | // Creates a server based on the given `server_def`, and stores it in |
121 | // `*out_server`. Returns OK on success, otherwise returns an error. |
122 | Status NewServer(const ServerDef& server_def, |
123 | std::unique_ptr<ServerInterface>* out_server); |
124 | Status NewServerWithOptions(const ServerDef& server_def, |
125 | const ServerFactory::Options& options, |
126 | std::unique_ptr<ServerInterface>* out_server); |
127 | |
128 | } // namespace tensorflow |
129 | |
130 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ |
131 | |