1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
18
#include <functional>
#include <utility>

#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
27
28namespace tensorflow {
29
30// Status callback.
31typedef std::function<void(const Status&)> StatusCallback;
32
33// Custom decoder for a response to RecvTensorAsync.
34class TensorResponse;
35
36// Interface for talking with the TensorFlow Worker service.
37class WorkerInterface {
38 public:
39 virtual void GetStatusAsync(CallOptions* opts,
40 const GetStatusRequest* request,
41 GetStatusResponse* response, bool fail_fast,
42 StatusCallback done) = 0;
43
44 virtual void CreateWorkerSessionAsync(
45 const CreateWorkerSessionRequest* request,
46 CreateWorkerSessionResponse* response, StatusCallback done) = 0;
47
48 virtual void DeleteWorkerSessionAsync(
49 CallOptions* opts, const DeleteWorkerSessionRequest* request,
50 DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
51
52 virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
53 RegisterGraphResponse* response,
54 StatusCallback done) = 0;
55
56 virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
57 DeregisterGraphResponse* response,
58 StatusCallback done) = 0;
59
60 virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
61 MutableRunGraphResponseWrapper* response,
62 StatusCallback done) = 0;
63
64 virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
65 RunGraphResponse* response, StatusCallback done) {
66 RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
67 MutableRunGraphResponseWrapper* wrapped_response =
68 new NonOwnedProtoRunGraphResponse(response);
69 RunGraphAsync(opts, wrapped_request, wrapped_response,
70 [wrapped_request, wrapped_response,
71 done = std::move(done)](const Status& s) {
72 done(s);
73 delete wrapped_request;
74 delete wrapped_response;
75 });
76 }
77
78 // Returns a request object for use in calls to
79 // `RunGraphAsync()`. Ownership is transferred to the caller.
80 //
81 // The message returned from this method must only be used in a
82 // `RunGraph()` call on the same `WorkerInterface` instance.
83 virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
84 return new MutableProtoRunGraphRequest;
85 }
86
87 // Returns a response object for use in calls to
88 // `RunGraphAsync()`. Ownership is transferred to the caller.
89 //
90 // The message returned from this method must only be used in a
91 // `RunGraph()` call on the same `WorkerInterface` instance.
92 virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
93 return new OwnedProtoRunGraphResponse;
94 }
95
96 virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
97 CleanupGraphResponse* response,
98 StatusCallback done) = 0;
99
100 virtual void CleanupAllAsync(const CleanupAllRequest* request,
101 CleanupAllResponse* response,
102 StatusCallback done) = 0;
103
104 virtual void RecvTensorAsync(CallOptions* opts,
105 const RecvTensorRequest* request,
106 TensorResponse* response,
107 StatusCallback done) = 0;
108
109 virtual void LoggingAsync(const LoggingRequest* request,
110 LoggingResponse* response, StatusCallback done) = 0;
111
112 virtual void TracingAsync(const TracingRequest* request,
113 TracingResponse* response, StatusCallback done) = 0;
114
115 virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
116 RecvBufResponse* response, StatusCallback done) = 0;
117
118 virtual void CompleteGroupAsync(CallOptions* opts,
119 const CompleteGroupRequest* request,
120 CompleteGroupResponse* response,
121 StatusCallback done) = 0;
122
123 virtual void CompleteInstanceAsync(CallOptions* ops,
124 const CompleteInstanceRequest* request,
125 CompleteInstanceResponse* response,
126 StatusCallback done) = 0;
127
128 virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
129 GetStepSequenceResponse* response,
130 StatusCallback done) = 0;
131
132 Status GetStatus(const GetStatusRequest* request,
133 GetStatusResponse* response) {
134 Status ret;
135 Notification n;
136 GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
137 [&ret, &n](const Status& s) {
138 ret = s;
139 n.Notify();
140 });
141 n.WaitForNotification();
142 return ret;
143 }
144
145 Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
146 CreateWorkerSessionResponse* response) {
147 return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
148 }
149
150 Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
151 DeleteWorkerSessionResponse* response) {
152 return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
153 response);
154 }
155
156 Status RegisterGraph(const RegisterGraphRequest* request,
157 RegisterGraphResponse* response) {
158 return CallAndWait(&ME::RegisterGraphAsync, request, response);
159 }
160
161 Status DeregisterGraph(const DeregisterGraphRequest* request,
162 DeregisterGraphResponse* response) {
163 return CallAndWait(&ME::DeregisterGraphAsync, request, response);
164 }
165
166 Status CleanupGraph(const CleanupGraphRequest* request,
167 CleanupGraphResponse* response) {
168 return CallAndWait(&ME::CleanupGraphAsync, request, response);
169 }
170
171 Status CleanupAll(const CleanupAllRequest* request,
172 CleanupAllResponse* response) {
173 return CallAndWait(&ME::CleanupAllAsync, request, response);
174 }
175
176 Status Logging(const LoggingRequest* request, LoggingResponse* response) {
177 return CallAndWait(&ME::LoggingAsync, request, response);
178 }
179
180 Status Tracing(const TracingRequest* request, TracingResponse* response) {
181 return CallAndWait(&ME::TracingAsync, request, response);
182 }
183
184 Status GetStepSequence(const GetStepSequenceRequest* request,
185 GetStepSequenceResponse* response) {
186 return CallAndWait(&ME::GetStepSequenceAsync, request, response);
187 }
188
189 protected:
190 // Instances of WorkerInterface must be deleted by a call to
191 // WorkerCacheInterface::ReleaseWorker().
192 virtual ~WorkerInterface() {}
193 friend class WorkerCacheInterface;
194
195 // NOTE: This should only be called by implementations of this
196 // interface whose CreateRunGraphResponse() method returns a
197 // proto-based wrappers for the RunGraphResponse message.
198 RunGraphResponse* get_proto_from_wrapper(
199 MutableRunGraphResponseWrapper* wrapper) {
200 return wrapper->get_proto();
201 }
202
203 private:
204 typedef WorkerInterface ME;
205
206 template <typename Method, typename Req, typename Resp>
207 Status CallAndWait(Method func, const Req* req, Resp* resp) {
208 Status ret;
209 Notification n;
210 (this->*func)(req, resp, [&ret, &n](const Status& s) {
211 ret = s;
212 n.Notify();
213 });
214 n.WaitForNotification();
215 return ret;
216 }
217
218 template <typename Method, typename Req, typename Resp>
219 Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
220 CallOptions call_opts;
221 Status ret;
222 Notification n;
223 (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
224 ret = s;
225 n.Notify();
226 });
227 n.WaitForNotification();
228 return ret;
229 }
230};
231
232} // namespace tensorflow
233
234#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
235