1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ |
18 | |
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
29 | |
30 | namespace tensorflow { |
31 | |
32 | //////////////////////////////////////////////////////////////////////////////// |
33 | // |
34 | // Wrapper classes for the `MasterService.RunStep` request message. |
35 | // |
36 | // The `RunStepRequest` message can contain potentially large tensor |
37 | // data as part of its `feed` submessages. Here we provide specialized |
38 | // wrappers that avoid copying the tensor data wherever possible. |
39 | // |
40 | // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the |
41 | // protocol buffer definition. |
42 | // |
43 | //////////////////////////////////////////////////////////////////////////////// |
44 | |
45 | // Abstract interface for an immutable RunStepRequest message. |
46 | // |
47 | // This interface is typically used by server-side components in the |
48 | // TensorFlow master. |
49 | class RunStepRequestWrapper { |
50 | public: |
51 | virtual ~RunStepRequestWrapper() {} |
52 | |
53 | // REQUIRED: session_handle must be returned by a CreateSession call |
54 | // to the same master service. |
55 | virtual const string& session_handle() const = 0; |
56 | |
57 | // Partial run handle (optional). If specified, this will be a partial run |
58 | // execution, run up to the specified fetches. |
59 | virtual const string& partial_run_handle() const = 0; |
60 | |
61 | // Tensors to be fed in the step. Each feed is a named tensor. |
62 | virtual size_t num_feeds() const = 0; |
63 | virtual const string& feed_name(size_t i) const = 0; |
64 | |
65 | // Stores the content of the feed value at index `i` in `tensor`. |
66 | virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; |
67 | virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; |
68 | |
69 | // Fetches. A list of tensor names. The caller expects a tensor to |
70 | // be returned for each fetch[i] (see RunStepResponse.tensor). The |
71 | // order of specified fetches does not change the execution order. |
72 | virtual size_t num_fetches() const = 0; |
73 | virtual const string& fetch_name(size_t i) const = 0; |
74 | |
75 | // Target Nodes. A list of node names. The named nodes will be run |
76 | // to but their outputs will not be fetched. |
77 | virtual size_t num_targets() const = 0; |
78 | virtual const string& target_name(size_t i) const = 0; |
79 | |
80 | // Options for the run call. |
81 | virtual const RunOptions& options() const = 0; |
82 | |
83 | // If true then some errors, e.g., execution errors that have long |
84 | // error messages, may return an OK RunStepResponse with the actual |
85 | // error saved in the status_code/status_error_message fields of the |
86 | // response body. This is a workaround since the RPC subsystem may |
87 | // truncate long metadata messages. |
88 | virtual bool store_errors_in_response_body() const = 0; |
89 | |
90 | // Unique identifier for this request. Every RunGraphRequest must have a |
91 | // unique request_id, and retried RunGraphRequests must have the same |
92 | // request_id. If request_id is zero, retry detection is disabled. |
93 | virtual int64_t request_id() const = 0; |
94 | |
95 | // Returns a human-readable representation of this message for debugging. |
96 | virtual string DebugString() const = 0; |
97 | |
98 | // Returns the wrapped data as a protocol buffer message. |
99 | virtual const RunStepRequest& ToProto() const = 0; |
100 | }; |
101 | |
102 | // Abstract interface for a mutable RunStepRequest message. |
103 | // |
104 | // See `RunStepRequestWrapper` above for a description of the fields. |
105 | class MutableRunStepRequestWrapper : public RunStepRequestWrapper { |
106 | public: |
107 | virtual void set_session_handle(const string& handle) = 0; |
108 | virtual void set_partial_run_handle(const string& handle) = 0; |
109 | virtual void add_feed(const string& name, const Tensor& value) = 0; |
110 | virtual void add_fetch(const string& name) = 0; |
111 | virtual void add_target(const string& name) = 0; |
112 | virtual RunOptions* mutable_options() = 0; |
113 | virtual void set_store_errors_in_response_body(bool store_errors) = 0; |
114 | }; |
115 | |
116 | // Specialized (and mutable) wrapper for RunStep requests between a client and |
117 | // master in the same address space. |
118 | class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { |
119 | public: |
120 | // RunStepRequestWrapper methods. |
121 | const string& session_handle() const override; |
122 | const string& partial_run_handle() const override; |
123 | size_t num_feeds() const override; |
124 | const string& feed_name(size_t i) const override; |
125 | Status FeedValue(size_t i, Tensor* out_tensor) const override; |
126 | Status FeedValue(size_t i, TensorProto* out_tensor) const override; |
127 | size_t num_fetches() const override; |
128 | const string& fetch_name(size_t i) const override; |
129 | size_t num_targets() const override; |
130 | const string& target_name(size_t i) const override; |
131 | const RunOptions& options() const override; |
132 | string DebugString() const override; |
133 | const RunStepRequest& ToProto() const override; |
134 | bool store_errors_in_response_body() const override; |
135 | int64_t request_id() const override; |
136 | |
137 | // MutableRunStepRequestWrapper methods. |
138 | void set_session_handle(const string& handle) override; |
139 | void set_partial_run_handle(const string& handle) override; |
140 | void add_feed(const string& name, const Tensor& value) override; |
141 | void add_fetch(const string& name) override; |
142 | void add_target(const string& name) override; |
143 | RunOptions* mutable_options() override; |
144 | void set_store_errors_in_response_body(bool store_errors) override; |
145 | |
146 | private: |
147 | string session_handle_; |
148 | string partial_run_handle_; |
149 | gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_; |
150 | gtl::InlinedVector<string, 4> fetches_; |
151 | gtl::InlinedVector<string, 4> targets_; |
152 | RunOptions options_; |
153 | bool store_errors_in_response_body_ = false; |
154 | |
155 | // Holds a cached and owned representation of the proto |
156 | // representation of this request, if needed, so that `ToProto()` |
157 | // can return a const RunStepRequest&. |
158 | // NOTE(mrry): Although calls to `ToProto()` on this class are |
159 | // expected to be rare, retaining ownership of the returned message |
160 | // makes it easier to return a reference from the proto-backed |
161 | // representations. |
162 | mutable std::unique_ptr<RunStepRequest> proto_version_; |
163 | }; |
164 | |
165 | // Wrapper for mutable RunStep requests that uses a protobuf message. |
166 | // |
167 | // This wrapper class should be used for RunStep requests between a |
168 | // client and master in different address spaces. |
169 | class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { |
170 | public: |
171 | // RunStepRequestWrapper methods. |
172 | const string& session_handle() const override; |
173 | const string& partial_run_handle() const override; |
174 | size_t num_feeds() const override; |
175 | const string& feed_name(size_t i) const override; |
176 | Status FeedValue(size_t i, Tensor* out_tensor) const override; |
177 | Status FeedValue(size_t i, TensorProto* out_tensor) const override; |
178 | size_t num_fetches() const override; |
179 | const string& fetch_name(size_t i) const override; |
180 | size_t num_targets() const override; |
181 | const string& target_name(size_t i) const override; |
182 | const RunOptions& options() const override; |
183 | string DebugString() const override; |
184 | const RunStepRequest& ToProto() const override; |
185 | bool store_errors_in_response_body() const override; |
186 | int64_t request_id() const override; |
187 | |
188 | // MutableRunStepRequestWrapper methods. |
189 | void set_session_handle(const string& handle) override; |
190 | void set_partial_run_handle(const string& handle) override; |
191 | void add_feed(const string& name, const Tensor& value) override; |
192 | void add_fetch(const string& name) override; |
193 | void add_target(const string& name) override; |
194 | RunOptions* mutable_options() override; |
195 | void set_store_errors_in_response_body(bool store_errors) override; |
196 | |
197 | private: |
198 | RunStepRequest request_; |
199 | friend class MasterInterface; |
200 | }; |
201 | |
202 | // Wrapper for immutable RunStep requests that use a non-owned |
203 | // protobuf message. |
204 | // |
205 | // This interface is typically used by server-side components in the |
206 | // TensorFlow master, where the incoming message is a (possibly const) |
207 | // `RunStepRequest*`. |
208 | class ProtoRunStepRequest : public RunStepRequestWrapper { |
209 | public: |
210 | ProtoRunStepRequest(const RunStepRequest* request); |
211 | |
212 | // RunStepRequestWrapper methods. |
213 | const string& session_handle() const override; |
214 | const string& partial_run_handle() const override; |
215 | size_t num_feeds() const override; |
216 | const string& feed_name(size_t i) const override; |
217 | Status FeedValue(size_t i, Tensor* out_tensor) const override; |
218 | Status FeedValue(size_t i, TensorProto* out_tensor) const override; |
219 | size_t num_fetches() const override; |
220 | const string& fetch_name(size_t i) const override; |
221 | size_t num_targets() const override; |
222 | const string& target_name(size_t i) const override; |
223 | const RunOptions& options() const override; |
224 | string DebugString() const override; |
225 | const RunStepRequest& ToProto() const override; |
226 | bool store_errors_in_response_body() const override; |
227 | int64_t request_id() const override; |
228 | |
229 | private: |
230 | const RunStepRequest* const request_; // Not owned. |
231 | }; |
232 | |
233 | //////////////////////////////////////////////////////////////////////////////// |
234 | // |
235 | // Wrapper classes for the `WorkerService.RunGraph` request message. |
236 | // |
237 | // The `RunGraphRequest` message can contain potentially large tensor |
238 | // data as part of its `send` submessages. Here we provide specialized |
239 | // wrappers that avoid copying the tensor data wherever possible. |
240 | // |
241 | // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the |
242 | // protocol buffer definition. |
243 | // |
244 | //////////////////////////////////////////////////////////////////////////////// |
245 | |
246 | // Abstract interface for an immutable RunGraphRequest message. |
247 | // |
248 | // This interface is typically used by server-side components in the |
249 | // TensorFlow worker. |
250 | class RunGraphRequestWrapper { |
251 | public: |
252 | virtual ~RunGraphRequestWrapper() {} |
253 | |
254 | // The session handle used to register the graph. If empty, a single global |
255 | // namespace is used. |
256 | virtual const string& session_handle() const = 0; |
257 | |
258 | // Set to true if `CreateWorkerSession` was called for `session_handle`. |
259 | virtual bool create_worker_session_called() const = 0; |
260 | |
261 | // REQUIRED: graph_handle must be returned by a RegisterGraph call |
262 | // to the same WorkerService. |
263 | virtual const string& graph_handle() const = 0; |
264 | |
265 | // A unique ID to distinguish different runs of the same graph. |
266 | // |
267 | // The master generates a global unique `step_id` to distinguish |
268 | // different runs of the graph computation. Subgraphs communicate |
269 | // (e.g., send/recv ops) with each other using `step_id` to |
270 | // distinguish tensors generated by different runs. |
271 | virtual int64_t step_id() const = 0; |
272 | |
273 | // Options for this step. |
274 | virtual const ExecutorOpts& exec_opts() const = 0; |
275 | |
276 | // Sends the tensors in "send" into the graph before the run. |
277 | virtual size_t num_sends() const = 0; |
278 | virtual const string& send_key(size_t i) const = 0; |
279 | virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0; |
280 | |
281 | // Fetches the keys into `RunGraphResponse.recv` after the run. |
282 | virtual size_t num_recvs() const = 0; |
283 | virtual const string& recv_key(size_t i) const = 0; |
284 | |
285 | // True if the RunGraphRequest is a partial run request. |
286 | virtual bool is_partial() const = 0; |
287 | |
288 | // True if this is the last partial run request in a sequence of requests. |
289 | virtual bool is_last_partial_run() const = 0; |
290 | |
291 | // If true then some errors, e.g., execution errors that have long |
292 | // error messages, may return an OK RunStepResponse with the actual |
293 | // error saved in the status_code/status_error_message fields of the |
294 | // response body. This is a workaround since the RPC subsystem may |
295 | // truncate long metadata messages. |
296 | virtual bool store_errors_in_response_body() const = 0; |
297 | |
298 | virtual int64_t request_id() const = 0; |
299 | |
300 | // Returns the wrapped data as a protocol buffer message. |
301 | virtual const RunGraphRequest& ToProto() const = 0; |
302 | }; |
303 | |
304 | // Abstract interface for a mutable RunGraphRequest message. |
305 | // |
306 | // See `RunGraphRequestWrapper` above for a description of the fields. |
307 | class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { |
308 | public: |
309 | virtual void set_session_handle(const string& handle) = 0; |
310 | virtual void set_create_worker_session_called(bool called) = 0; |
311 | virtual void set_graph_handle(const string& handle) = 0; |
312 | virtual void set_step_id(int64_t step_id) = 0; |
313 | virtual ExecutorOpts* mutable_exec_opts() = 0; |
314 | |
315 | // Stores the i^{th} feed value in `run_step_request` in this |
316 | // request with the given `send_key`. |
317 | virtual Status AddSendFromRunStepRequest( |
318 | const RunStepRequestWrapper& run_step_request, size_t i, |
319 | const string& send_key) = 0; |
320 | virtual Status AddSendFromRunCallableRequest( |
321 | const RunCallableRequest& run_callable_request, size_t i, |
322 | const string& send_key) = 0; |
323 | |
324 | virtual void add_recv_key(const string& recv_key) = 0; |
325 | virtual void set_is_partial(bool is_partial) = 0; |
326 | virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; |
327 | virtual void set_store_errors_in_response_body(bool store_errors) = 0; |
328 | virtual void set_request_id(int64_t request_id) = 0; |
329 | }; |
330 | |
331 | class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { |
332 | public: |
333 | // RunGraphRequestWrapper methods. |
334 | const string& session_handle() const override; |
335 | const string& graph_handle() const override; |
336 | bool create_worker_session_called() const override; |
337 | int64_t step_id() const override; |
338 | const ExecutorOpts& exec_opts() const override; |
339 | size_t num_sends() const override; |
340 | const string& send_key(size_t i) const override; |
341 | Status SendValue(size_t i, Tensor* out_tensor) const override; |
342 | size_t num_recvs() const override; |
343 | const string& recv_key(size_t i) const override; |
344 | bool is_partial() const override; |
345 | bool is_last_partial_run() const override; |
346 | const RunGraphRequest& ToProto() const override; |
347 | bool store_errors_in_response_body() const override; |
348 | int64_t request_id() const override; |
349 | |
350 | // MutableRunGraphRequestWrapper methods. |
351 | void set_session_handle(const string& handle) override; |
352 | void set_create_worker_session_called(bool called) override; |
353 | void set_graph_handle(const string& handle) override; |
354 | void set_step_id(int64_t step_id) override; |
355 | ExecutorOpts* mutable_exec_opts() override; |
356 | Status AddSendFromRunStepRequest( |
357 | const RunStepRequestWrapper& run_step_request, size_t i, |
358 | const string& send_key) override; |
359 | Status AddSendFromRunCallableRequest( |
360 | const RunCallableRequest& run_callable_request, size_t i, |
361 | const string& send_key) override; |
362 | void add_recv_key(const string& recv_key) override; |
363 | void set_is_partial(bool is_partial) override; |
364 | void set_is_last_partial_run(bool is_last_partial_run) override; |
365 | void set_store_errors_in_response_body(bool store_errors) override; |
366 | void set_request_id(int64_t request_id) override; |
367 | |
368 | private: |
369 | string session_handle_; |
370 | bool create_worker_session_called_ = false; |
371 | string graph_handle_; |
372 | int64_t step_id_; |
373 | ExecutorOpts exec_opts_; |
374 | gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_; |
375 | gtl::InlinedVector<string, 4> recvs_; |
376 | bool is_partial_ = false; |
377 | bool is_last_partial_run_ = false; |
378 | bool store_errors_in_response_body_ = false; |
379 | int64_t request_id_ = 0; |
380 | |
381 | // Holds a cached and owned representation of the proto |
382 | // representation of this request, if needed, so that `ToProto()` |
383 | // can return a const RunGraphRequest&. |
384 | // NOTE(mrry): Although calls to `ToProto()` on this class are |
385 | // expected to be rare, retaining ownership of the returned message |
386 | // makes it easier to return a reference from the proto-backed |
387 | // representations. |
388 | mutable std::unique_ptr<RunGraphRequest> proto_version_; |
389 | }; |
390 | |
391 | class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { |
392 | public: |
393 | // RunGraphRequestWrapper methods. |
394 | const string& session_handle() const override; |
395 | bool create_worker_session_called() const override; |
396 | const string& graph_handle() const override; |
397 | int64_t step_id() const override; |
398 | const ExecutorOpts& exec_opts() const override; |
399 | size_t num_sends() const override; |
400 | const string& send_key(size_t i) const override; |
401 | Status SendValue(size_t i, Tensor* out_tensor) const override; |
402 | size_t num_recvs() const override; |
403 | const string& recv_key(size_t i) const override; |
404 | bool is_partial() const override; |
405 | bool is_last_partial_run() const override; |
406 | bool store_errors_in_response_body() const override; |
407 | int64_t request_id() const override; |
408 | const RunGraphRequest& ToProto() const override; |
409 | |
410 | // MutableRunGraphRequestWrapper methods. |
411 | void set_session_handle(const string& handle) override; |
412 | void set_create_worker_session_called(bool called) override; |
413 | void set_graph_handle(const string& handle) override; |
414 | void set_step_id(int64_t step_id) override; |
415 | ExecutorOpts* mutable_exec_opts() override; |
416 | Status AddSendFromRunStepRequest( |
417 | const RunStepRequestWrapper& run_step_request, size_t i, |
418 | const string& send_key) override; |
419 | Status AddSendFromRunCallableRequest( |
420 | const RunCallableRequest& run_callable_request, size_t i, |
421 | const string& send_key) override; |
422 | void add_recv_key(const string& recv_key) override; |
423 | void set_is_partial(bool is_partial) override; |
424 | void set_is_last_partial_run(bool is_last_partial_run) override; |
425 | void set_store_errors_in_response_body(bool store_errors) override; |
426 | void set_request_id(int64_t request_id) override; |
427 | |
428 | private: |
429 | RunGraphRequest request_; |
430 | }; |
431 | |
432 | class ProtoRunGraphRequest : public RunGraphRequestWrapper { |
433 | public: |
434 | ProtoRunGraphRequest(const RunGraphRequest* request); |
435 | |
436 | // RunGraphRequestWrapper methods. |
437 | const string& session_handle() const override; |
438 | bool create_worker_session_called() const override; |
439 | const string& graph_handle() const override; |
440 | int64_t step_id() const override; |
441 | const ExecutorOpts& exec_opts() const override; |
442 | size_t num_sends() const override; |
443 | const string& send_key(size_t i) const override; |
444 | Status SendValue(size_t i, Tensor* out_tensor) const override; |
445 | size_t num_recvs() const override; |
446 | const string& recv_key(size_t i) const override; |
447 | bool is_partial() const override; |
448 | bool is_last_partial_run() const override; |
449 | bool store_errors_in_response_body() const override; |
450 | int64_t request_id() const override; |
451 | const RunGraphRequest& ToProto() const override; |
452 | |
453 | private: |
454 | const RunGraphRequest* const request_; // Not owned. |
455 | }; |
456 | |
457 | //////////////////////////////////////////////////////////////////////////////// |
458 | // |
459 | // Wrapper classes for the `WorkerService.RunGraph` response message. |
460 | // |
461 | // The `RunGraphResponse` message can contain potentially large tensor |
462 | // data as part of its `recv` submessages. Here we provide specialized |
463 | // wrappers that avoid copying the tensor data wherever possible. |
464 | // |
465 | // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the |
466 | // protocol buffer definition. |
467 | // |
468 | //////////////////////////////////////////////////////////////////////////////// |
469 | |
470 | // Abstract interface for a mutable RunGraphResponse message. |
471 | // |
472 | // Note that there is no corresponding (immutable) |
473 | // RunGraphResponseWrapper class, because the RunGraphResponse object |
474 | // is always used as a mutable pointer. |
475 | class MutableRunGraphResponseWrapper { |
476 | public: |
477 | virtual ~MutableRunGraphResponseWrapper() {} |
478 | |
479 | // A list of tensors corresponding to those requested by |
480 | // `RunGraphRequest.recv_key`. |
481 | virtual size_t num_recvs() const = 0; |
482 | virtual const string& recv_key(size_t i) const = 0; |
483 | // NOTE: The following methods may perform a destructive read, for |
484 | // efficiency. |
485 | virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; |
486 | virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; |
487 | virtual void AddRecv(const string& key, const Tensor& value) = 0; |
488 | |
489 | // Submessages that store performance statistics about the subgraph |
490 | // execution, if necessary. |
491 | virtual StepStats* mutable_step_stats() = 0; |
492 | virtual CostGraphDef* mutable_cost_graph() = 0; |
493 | virtual size_t num_partition_graphs() const = 0; |
494 | virtual GraphDef* mutable_partition_graph(size_t i) = 0; |
495 | virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; |
496 | |
497 | // Returned status if requested. |
498 | virtual Status status() const = 0; |
499 | virtual errors::Code status_code() const = 0; |
500 | virtual const string& status_error_message() const = 0; |
501 | virtual void set_status(const Status& status) = 0; |
502 | |
503 | protected: |
504 | // Returns a mutable protobuf message that represents the contents of |
505 | // this wrapper, for passing to an RPC subsystem that will populate |
506 | // the message. |
507 | // |
508 | // NOTE: Only `WorkerInterface` subclasses may call this method. The |
509 | // `InMemoryRunGraphResponse` subclass does not implement this |
510 | // method, and attempts to call it will fail with a fatal |
511 | // error. However, as long as callers always call |
512 | // `WorkerInterface::RunGraphAsync()` with a wrapper object returned |
513 | // from `WorkerInterface::CreateRunGraphResponse()` called on the |
514 | // *same* WorkerInterface object, this error will never trigger. |
515 | virtual RunGraphResponse* get_proto() = 0; |
516 | friend class WorkerInterface; |
517 | }; |
518 | |
519 | class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { |
520 | public: |
521 | // MutableRunGraphResponseWrapper methods. |
522 | size_t num_recvs() const override; |
523 | const string& recv_key(size_t i) const override; |
524 | Status RecvValue(size_t i, TensorProto* out_tensor) override; |
525 | Status RecvValue(size_t i, Tensor* out_tensor) override; |
526 | void AddRecv(const string& key, const Tensor& value) override; |
527 | StepStats* mutable_step_stats() override; |
528 | CostGraphDef* mutable_cost_graph() override; |
529 | size_t num_partition_graphs() const override; |
530 | GraphDef* mutable_partition_graph(size_t i) override; |
531 | void AddPartitionGraph(const GraphDef& partition_graph) override; |
532 | Status status() const override; |
533 | errors::Code status_code() const override; |
534 | const string& status_error_message() const override; |
535 | void set_status(const Status& status) override; |
536 | |
537 | protected: |
538 | // NOTE: This method is not implemented. See |
539 | // MutableRunGraphResponseWrapper for an explanation. |
540 | RunGraphResponse* get_proto() override; |
541 | |
542 | private: |
543 | gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; |
544 | StepStats step_stats_; |
545 | CostGraphDef cost_graph_; |
546 | std::vector<GraphDef> partition_graphs_; |
547 | // Store the code and message separately so that they can be updated |
548 | // independently by setters. |
549 | Status status_; |
550 | }; |
551 | |
552 | // Proto-based message wrapper for use on the client side of the RunGraph RPC. |
553 | class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { |
554 | public: |
555 | // MutableRunGraphResponseWrapper methods. |
556 | size_t num_recvs() const override; |
557 | const string& recv_key(size_t i) const override; |
558 | Status RecvValue(size_t i, TensorProto* out_tensor) override; |
559 | Status RecvValue(size_t i, Tensor* out_tensor) override; |
560 | void AddRecv(const string& key, const Tensor& value) override; |
561 | StepStats* mutable_step_stats() override; |
562 | CostGraphDef* mutable_cost_graph() override; |
563 | size_t num_partition_graphs() const override; |
564 | GraphDef* mutable_partition_graph(size_t i) override; |
565 | void AddPartitionGraph(const GraphDef& partition_graph) override; |
566 | Status status() const override; |
567 | errors::Code status_code() const override; |
568 | const string& status_error_message() const override; |
569 | void set_status(const Status& status) override; |
570 | |
571 | protected: |
572 | RunGraphResponse* get_proto() override; |
573 | |
574 | private: |
575 | RunGraphResponse response_; |
576 | }; |
577 | |
578 | // Proto-based message wrapper for use on the server side of the RunGraph RPC. |
579 | class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { |
580 | public: |
581 | NonOwnedProtoRunGraphResponse(RunGraphResponse* response); |
582 | |
583 | // MutableRunGraphResponseWrapper methods. |
584 | size_t num_recvs() const override; |
585 | const string& recv_key(size_t i) const override; |
586 | Status RecvValue(size_t i, TensorProto* out_tensor) override; |
587 | Status RecvValue(size_t i, Tensor* out_tensor) override; |
588 | void AddRecv(const string& key, const Tensor& value) override; |
589 | StepStats* mutable_step_stats() override; |
590 | CostGraphDef* mutable_cost_graph() override; |
591 | size_t num_partition_graphs() const override; |
592 | GraphDef* mutable_partition_graph(size_t i) override; |
593 | void AddPartitionGraph(const GraphDef& partition_graph) override; |
594 | Status status() const override; |
595 | errors::Code status_code() const override; |
596 | const string& status_error_message() const override; |
597 | void set_status(const Status& status) override; |
598 | |
599 | protected: |
600 | RunGraphResponse* get_proto() override; |
601 | |
602 | private: |
603 | RunGraphResponse* const response_; |
604 | }; |
605 | |
606 | //////////////////////////////////////////////////////////////////////////////// |
607 | // |
608 | // Wrapper classes for the `MasterService.RunStep` response message. |
609 | // |
610 | // The `RunStepResponse` message can contain potentially large tensor |
611 | // data as part of its `tensor` submessages. Here we provide specialized |
612 | // wrappers that avoid copying the tensor data wherever possible. |
613 | // |
614 | // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the |
615 | // protocol buffer definition. |
616 | // |
617 | //////////////////////////////////////////////////////////////////////////////// |
618 | |
619 | // Abstract interface for a mutable RunStepResponse message. |
620 | // |
621 | // Note that there is no corresponding (immutable) |
622 | // RunStepResponseWrapper class, because the RunStepResponse object is |
623 | // always used as a mutable pointer. |
624 | class MutableRunStepResponseWrapper { |
625 | public: |
626 | virtual ~MutableRunStepResponseWrapper(); |
627 | |
628 | // The values of the tensors whose fetching was requested in the |
629 | // RunStep call. |
630 | // |
631 | // NOTE: The order of the returned tensors may or may not match |
632 | // the fetch order specified in RunStepRequest. |
633 | virtual size_t num_tensors() const = 0; |
634 | virtual const string& tensor_name(size_t i) const = 0; |
635 | virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; |
636 | |
637 | // Stores the i^{th} recv value in `run_graph_response` in this |
638 | // response with the given `name`. |
639 | virtual Status AddTensorFromRunGraphResponse( |
640 | const string& name, MutableRunGraphResponseWrapper* run_graph_response, |
641 | size_t i) = 0; |
642 | |
643 | // Returned metadata if requested in the options. |
644 | virtual const RunMetadata& metadata() const = 0; |
645 | virtual RunMetadata* mutable_metadata() = 0; |
646 | |
647 | // Returned status if requested. |
648 | virtual Status status() const = 0; |
649 | virtual errors::Code status_code() const = 0; |
650 | virtual const string& status_error_message() const = 0; |
651 | virtual void set_status(const Status& status) = 0; |
652 | |
653 | protected: |
654 | // Returns a mutable protobuf message that represents the contents of |
655 | // this wrapper, for passing to an RPC subsystem that will populate |
656 | // the message. |
657 | // |
658 | // NOTE: Only `MasterInterface` subclasses may call this method. The |
659 | // `InMemoryRunStepResponse` subclass does not implement this |
660 | // method, and attempts to call it will fail with a fatal |
661 | // error. However, as long as callers always call |
662 | // `MasterInterface::RunStep()` with a wrapper object returned |
663 | // from `MasterInterface::CreateRunStepResponse()` called on the |
664 | // *same* MasterInterface object, this error will never trigger. |
665 | virtual RunStepResponse* get_proto() = 0; |
666 | friend class MasterInterface; |
667 | }; |
668 | |
669 | class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { |
670 | public: |
671 | // MutableRunStepResponseWrapper methods. |
672 | size_t num_tensors() const override; |
673 | const string& tensor_name(size_t i) const override; |
674 | Status TensorValue(size_t i, Tensor* out_tensor) const override; |
675 | Status AddTensorFromRunGraphResponse( |
676 | const string& name, MutableRunGraphResponseWrapper* run_graph_response, |
677 | size_t i) override; |
678 | const RunMetadata& metadata() const override; |
679 | RunMetadata* mutable_metadata() override; |
680 | Status status() const override; |
681 | errors::Code status_code() const override; |
682 | const string& status_error_message() const override; |
683 | void set_status(const Status& status) override; |
684 | |
685 | protected: |
686 | // NOTE: This method is not implemented. See |
687 | // MutableRunGraphResponseWrapper for an explanation. |
688 | RunStepResponse* get_proto() override; |
689 | |
690 | private: |
691 | gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; |
692 | RunMetadata metadata_; |
693 | // Store the code and message separately so that they can be updated |
694 | // independently by setters. |
695 | Status status_; |
696 | }; |
697 | |
698 | // Proto-based message wrapper for use on the client side of the RunStep RPC. |
699 | class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { |
700 | public: |
701 | // MutableRunStepResponseWrapper methods. |
702 | size_t num_tensors() const override; |
703 | const string& tensor_name(size_t i) const override; |
704 | Status TensorValue(size_t i, Tensor* out_tensor) const override; |
705 | Status AddTensorFromRunGraphResponse( |
706 | const string& name, MutableRunGraphResponseWrapper* run_graph_response, |
707 | size_t i) override; |
708 | const RunMetadata& metadata() const override; |
709 | RunMetadata* mutable_metadata() override; |
710 | Status status() const override; |
711 | errors::Code status_code() const override; |
712 | const string& status_error_message() const override; |
713 | void set_status(const Status& status) override; |
714 | |
715 | protected: |
716 | RunStepResponse* get_proto() override; |
717 | |
718 | private: |
719 | RunStepResponse response_; |
720 | }; |
721 | |
722 | // Proto-based message wrapper for use on the server side of the RunStep RPC. |
723 | class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { |
724 | public: |
725 | NonOwnedProtoRunStepResponse(RunStepResponse* response); |
726 | |
727 | // MutableRunStepResponseWrapper methods. |
728 | size_t num_tensors() const override; |
729 | const string& tensor_name(size_t i) const override; |
730 | Status TensorValue(size_t i, Tensor* out_tensor) const override; |
731 | Status AddTensorFromRunGraphResponse( |
732 | const string& name, MutableRunGraphResponseWrapper* run_graph_response, |
733 | size_t i) override; |
734 | const RunMetadata& metadata() const override; |
735 | RunMetadata* mutable_metadata() override; |
736 | Status status() const override; |
737 | errors::Code status_code() const override; |
738 | const string& status_error_message() const override; |
739 | void set_status(const Status& status) override; |
740 | |
741 | protected: |
742 | RunStepResponse* get_proto() override; |
743 | |
744 | private: |
745 | RunStepResponse* response_; // Not owned. |
746 | }; |
747 | |
748 | bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, |
749 | Tensor* out_tensor); |
750 | |
751 | } // namespace tensorflow |
752 | |
753 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ |
754 | |