1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
18
19#include <memory>
20
21#include "tensorflow/core/lib/core/status.h"
22#include "tensorflow/core/platform/macros.h"
23#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
24
25namespace tensorflow {
26
// Forward declarations. This header only refers to these types through
// pointers, so their full definitions are not needed here.
class CoordinationServiceAgent;
class DeviceMgr;
class EagerContext;
class MasterEnv;
class WorkerEnv;
32
33// This library supports a registration/factory-based mechanism for
34// creating TensorFlow server objects. Each server implementation must
35// have an accompanying implementation of ServerFactory, and create a
36// static "registrar" object that calls `ServerFactory::Register()`
37// with an instance of the factory class. See "rpc/grpc_server_lib.cc"
38// for an example.
39
40// Represents a single TensorFlow server that exports Master and Worker
41// services.
42class ServerInterface {
43 public:
44 ServerInterface() {}
45 virtual ~ServerInterface() {}
46
47 // Starts the server running asynchronously. Returns OK on success, otherwise
48 // returns an error.
49 virtual Status Start() = 0;
50
51 // Stops the server asynchronously. Returns OK on success, otherwise returns
52 // an error.
53 //
54 // After calling `Stop()`, the caller may call `Join()` to block until the
55 // server has stopped.
56 virtual Status Stop() = 0;
57
58 // Blocks until the server has stopped. Returns OK on success, otherwise
59 // returns an error.
60 virtual Status Join() = 0;
61
62 // Returns a target string that can be used to connect to this server using
63 // `tensorflow::NewSession()`.
64 virtual const string target() const = 0;
65
66 virtual WorkerEnv* worker_env() = 0;
67 virtual MasterEnv* master_env() = 0;
68
69 // Update the set of workers that can be reached by the server
70 virtual Status UpdateServerDef(const ServerDef& server_def) = 0;
71
72 // Functions to operate on service-specific properties.
73 //
74 // Add master eager context to local eager service in order to handle enqueue
75 // requests from remote workers.
76 virtual Status AddMasterEagerContextToEagerService(
77 const tensorflow::uint64 context_id, EagerContext* context) = 0;
78 // Set coordination service agent instance to coordination service RPC handler
79 virtual Status SetCoordinationServiceAgentInstance(
80 CoordinationServiceAgent* agent) = 0;
81 // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is
82 // supported.
83 virtual Status StopCoordinationService() = 0;
84
85 private:
86 TF_DISALLOW_COPY_AND_ASSIGN(ServerInterface);
87};
88
89class ServerFactory {
90 public:
91 struct Options {
92 // Local DeviceMgr to use.
93 tensorflow::DeviceMgr* local_device_mgr;
94 };
95 // Creates a new server based on the given `server_def`, and stores
96 // it in `*out_server`. Returns OK on success, otherwise returns an
97 // error.
98 virtual Status NewServer(const ServerDef& server_def, const Options& options,
99 std::unique_ptr<ServerInterface>* out_server) = 0;
100
101 // Returns true if and only if this factory can create a server
102 // based on the given `server_def`.
103 virtual bool AcceptsOptions(const ServerDef& server_def) = 0;
104
105 virtual ~ServerFactory() {}
106
107 // For each `ServerFactory` subclass, an instance of that class must
108 // be registered by calling this method.
109 //
110 // The `server_type` must be unique to the server factory.
111 static void Register(const string& server_type, ServerFactory* factory);
112
113 // Looks up a factory that can create a server based on the given
114 // `server_def`, and stores it in `*out_factory`. Returns OK on
115 // success, otherwise returns an error.
116 static Status GetFactory(const ServerDef& server_def,
117 ServerFactory** out_factory);
118};
119
120// Creates a server based on the given `server_def`, and stores it in
121// `*out_server`. Returns OK on success, otherwise returns an error.
122Status NewServer(const ServerDef& server_def,
123 std::unique_ptr<ServerInterface>* out_server);
124Status NewServerWithOptions(const ServerDef& server_def,
125 const ServerFactory::Options& options,
126 std::unique_ptr<ServerInterface>* out_server);
127
128} // namespace tensorflow
129
130#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
131