/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `MasterService.RunStep` request message.
//
// The `RunStepRequest` message can contain potentially large tensor
// data as part of its `feed` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for an immutable RunStepRequest message.
//
// This interface is typically used by server-side components in the
// TensorFlow master.
class RunStepRequestWrapper {
 public:
  virtual ~RunStepRequestWrapper() {}

  // REQUIRED: session_handle must be returned by a CreateSession call
  // to the same master service.
  virtual const string& session_handle() const = 0;

  // Partial run handle (optional). If specified, this will be a partial run
  // execution, run up to the specified fetches.
  virtual const string& partial_run_handle() const = 0;

  // Tensors to be fed in the step. Each feed is a named tensor.
  virtual size_t num_feeds() const = 0;
  virtual const string& feed_name(size_t i) const = 0;
  // Stores the content of the feed value at index `i` in `out_tensor`.
  virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
  virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;

  // Fetches. A list of tensor names. The caller expects a tensor to
  // be returned for each fetch[i] (see RunStepResponse.tensor). The
  // order of specified fetches does not change the execution order.
  virtual size_t num_fetches() const = 0;
  virtual const string& fetch_name(size_t i) const = 0;

  // Target Nodes. A list of node names. The named nodes will be run,
  // but their outputs will not be fetched.
  virtual size_t num_targets() const = 0;
  virtual const string& target_name(size_t i) const = 0;

  // Options for the run call.
  virtual const RunOptions& options() const = 0;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunStepResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  virtual bool store_errors_in_response_body() const = 0;

  // Unique identifier for this request. Every RunStepRequest must have a
  // unique request_id, and retried RunStepRequests must have the same
  // request_id. If request_id is zero, retry detection is disabled.
  virtual int64_t request_id() const = 0;

  // Returns a human-readable representation of this message for debugging.
  virtual string DebugString() const = 0;

  // Returns the wrapped data as a protocol buffer message.
  virtual const RunStepRequest& ToProto() const = 0;
};
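
// For illustration, a minimal sketch of how a master-side component might read
// feeds and fetches out of a `RunStepRequestWrapper`. `LogFeedsAndFetches` is a
// hypothetical helper, not part of this header:
//
//   Status LogFeedsAndFetches(const RunStepRequestWrapper& req) {
//     for (size_t i = 0; i < req.num_feeds(); ++i) {
//       Tensor feed;
//       TF_RETURN_IF_ERROR(req.FeedValue(i, &feed));
//       LOG(INFO) << "feed " << req.feed_name(i) << ": "
//                 << feed.shape().DebugString();
//     }
//     for (size_t i = 0; i < req.num_fetches(); ++i) {
//       LOG(INFO) << "fetch " << req.fetch_name(i);
//     }
//     return Status::OK();
//   }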

// Abstract interface for a mutable RunStepRequest message.
//
// See `RunStepRequestWrapper` above for a description of the fields.
class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
 public:
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_partial_run_handle(const string& handle) = 0;
  virtual void add_feed(const string& name, const Tensor& value) = 0;
  virtual void add_fetch(const string& name) = 0;
  virtual void add_target(const string& name) = 0;
  virtual RunOptions* mutable_options() = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
};

// Specialized (and mutable) wrapper for RunStep requests between a client and
// master in the same address space.
class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  string session_handle_;
  string partial_run_handle_;
  gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
  gtl::InlinedVector<string, 4> fetches_;
  gtl::InlinedVector<string, 4> targets_;
  RunOptions options_;
  bool store_errors_in_response_body_ = false;

  // Holds a cached and owned proto representation of this request,
  // if needed, so that `ToProto()` can return a const RunStepRequest&.
  // NOTE(mrry): Although calls to `ToProto()` on this class are
  // expected to be rare, retaining ownership of the returned message
  // makes it easier to return a reference from the proto-backed
  // representations.
  mutable std::unique_ptr<RunStepRequest> proto_version_;
};
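
// For illustration, a sketch of building an in-memory request on the client
// side when the client and master share an address space. The session handle
// and tensor names below are hypothetical:
//
//   InMemoryRunStepRequest request;
//   request.set_session_handle(session_handle);
//   Tensor x(DT_FLOAT, TensorShape({2}));
//   x.flat<float>().setZero();
//   request.add_feed("x:0", x);
//   request.add_fetch("y:0");
//   request.add_target("init_op");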

// Wrapper for mutable RunStep requests that uses a protobuf message.
//
// This wrapper class should be used for RunStep requests between a
// client and master in different address spaces.
class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
 public:
  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_partial_run_handle(const string& handle) override;
  void add_feed(const string& name, const Tensor& value) override;
  void add_fetch(const string& name) override;
  void add_target(const string& name) override;
  RunOptions* mutable_options() override;
  void set_store_errors_in_response_body(bool store_errors) override;

 private:
  RunStepRequest request_;
  friend class MasterInterface;
};

// Wrapper for immutable RunStep requests that use a non-owned
// protobuf message.
//
// This interface is typically used by server-side components in the
// TensorFlow master, where the incoming message is a (possibly const)
// `RunStepRequest*`.
class ProtoRunStepRequest : public RunStepRequestWrapper {
 public:
  ProtoRunStepRequest(const RunStepRequest* request);

  // RunStepRequestWrapper methods.
  const string& session_handle() const override;
  const string& partial_run_handle() const override;
  size_t num_feeds() const override;
  const string& feed_name(size_t i) const override;
  Status FeedValue(size_t i, Tensor* out_tensor) const override;
  Status FeedValue(size_t i, TensorProto* out_tensor) const override;
  size_t num_fetches() const override;
  const string& fetch_name(size_t i) const override;
  size_t num_targets() const override;
  const string& target_name(size_t i) const override;
  const RunOptions& options() const override;
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;

 private:
  const RunStepRequest* const request_;  // Not owned.
};
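
// For illustration, a sketch of how a master-side RPC handler might wrap an
// incoming proto without copying its tensor data. `proto_request` is a
// hypothetical `const RunStepRequest*` provided by the RPC layer:
//
//   ProtoRunStepRequest wrapped(proto_request);
//   for (size_t i = 0; i < wrapped.num_feeds(); ++i) {
//     Tensor feed;
//     Status s = wrapped.FeedValue(i, &feed);
//     if (!s.ok()) {
//       // Handle a malformed feed value.
//     }
//   }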

////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `WorkerService.RunGraph` request message.
//
// The `RunGraphRequest` message can contain potentially large tensor
// data as part of its `send` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for an immutable RunGraphRequest message.
//
// This interface is typically used by server-side components in the
// TensorFlow worker.
class RunGraphRequestWrapper {
 public:
  virtual ~RunGraphRequestWrapper() {}

  // The session handle used to register the graph. If empty, a single global
  // namespace is used.
  virtual const string& session_handle() const = 0;

  // Set to true if `CreateWorkerSession` was called for `session_handle`.
  virtual bool create_worker_session_called() const = 0;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  virtual const string& graph_handle() const = 0;

  // A unique ID to distinguish different runs of the same graph.
  //
  // The master generates a globally unique `step_id` to distinguish
  // different runs of the graph computation. Subgraphs communicate
  // (e.g., send/recv ops) with each other using `step_id` to
  // distinguish tensors generated by different runs.
  virtual int64_t step_id() const = 0;

  // Options for this step.
  virtual const ExecutorOpts& exec_opts() const = 0;

  // Sends the tensors in "send" into the graph before the run.
  virtual size_t num_sends() const = 0;
  virtual const string& send_key(size_t i) const = 0;
  virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;

  // Fetches the keys into `RunGraphResponse.recv` after the run.
  virtual size_t num_recvs() const = 0;
  virtual const string& recv_key(size_t i) const = 0;

  // True if the RunGraphRequest is a partial run request.
  virtual bool is_partial() const = 0;

  // True if this is the last partial run request in a sequence of requests.
  virtual bool is_last_partial_run() const = 0;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunGraphResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  virtual bool store_errors_in_response_body() const = 0;

  virtual int64_t request_id() const = 0;

  // Returns the wrapped data as a protocol buffer message.
  virtual const RunGraphRequest& ToProto() const = 0;
};

// Abstract interface for a mutable RunGraphRequest message.
//
// See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
 public:
  virtual void set_session_handle(const string& handle) = 0;
  virtual void set_create_worker_session_called(bool called) = 0;
  virtual void set_graph_handle(const string& handle) = 0;
  virtual void set_step_id(int64_t step_id) = 0;
  virtual ExecutorOpts* mutable_exec_opts() = 0;

  // Stores the i^{th} feed value in `run_step_request` in this
  // request with the given `send_key`.
  virtual Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) = 0;
  virtual Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) = 0;

  virtual void add_recv_key(const string& recv_key) = 0;
  virtual void set_is_partial(bool is_partial) = 0;
  virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
  virtual void set_request_id(int64_t request_id) = 0;
};
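
// For illustration, a simplified sketch of how the master might forward the
// feeds of a client RunStep request into a worker-bound RunGraph request. Real
// callers derive rendezvous send keys; here the feed name is reused directly,
// and `ForwardFeeds` is a hypothetical helper:
//
//   Status ForwardFeeds(const RunStepRequestWrapper& run_step_request,
//                       MutableRunGraphRequestWrapper* run_graph_request) {
//     for (size_t i = 0; i < run_step_request.num_feeds(); ++i) {
//       TF_RETURN_IF_ERROR(run_graph_request->AddSendFromRunStepRequest(
//           run_step_request, i, run_step_request.feed_name(i)));
//     }
//     return Status::OK();
//   }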

class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
 public:
  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  const string& graph_handle() const override;
  bool create_worker_session_called() const override;
  int64_t step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  const RunGraphRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;

  // MutableRunGraphRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_create_worker_session_called(bool called) override;
  void set_graph_handle(const string& handle) override;
  void set_step_id(int64_t step_id) override;
  ExecutorOpts* mutable_exec_opts() override;
  Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) override;
  Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) override;
  void add_recv_key(const string& recv_key) override;
  void set_is_partial(bool is_partial) override;
  void set_is_last_partial_run(bool is_last_partial_run) override;
  void set_store_errors_in_response_body(bool store_errors) override;
  void set_request_id(int64_t request_id) override;

 private:
  string session_handle_;
  bool create_worker_session_called_ = false;
  string graph_handle_;
  int64_t step_id_;
  ExecutorOpts exec_opts_;
  gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
  gtl::InlinedVector<string, 4> recvs_;
  bool is_partial_ = false;
  bool is_last_partial_run_ = false;
  bool store_errors_in_response_body_ = false;
  int64_t request_id_ = 0;

  // Holds a cached and owned proto representation of this request,
  // if needed, so that `ToProto()` can return a const RunGraphRequest&.
  // NOTE(mrry): Although calls to `ToProto()` on this class are
  // expected to be rare, retaining ownership of the returned message
  // makes it easier to return a reference from the proto-backed
  // representations.
  mutable std::unique_ptr<RunGraphRequest> proto_version_;
};

class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
 public:
  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  bool create_worker_session_called() const override;
  const string& graph_handle() const override;
  int64_t step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;
  const RunGraphRequest& ToProto() const override;

  // MutableRunGraphRequestWrapper methods.
  void set_session_handle(const string& handle) override;
  void set_create_worker_session_called(bool called) override;
  void set_graph_handle(const string& handle) override;
  void set_step_id(int64_t step_id) override;
  ExecutorOpts* mutable_exec_opts() override;
  Status AddSendFromRunStepRequest(
      const RunStepRequestWrapper& run_step_request, size_t i,
      const string& send_key) override;
  Status AddSendFromRunCallableRequest(
      const RunCallableRequest& run_callable_request, size_t i,
      const string& send_key) override;
  void add_recv_key(const string& recv_key) override;
  void set_is_partial(bool is_partial) override;
  void set_is_last_partial_run(bool is_last_partial_run) override;
  void set_store_errors_in_response_body(bool store_errors) override;
  void set_request_id(int64_t request_id) override;

 private:
  RunGraphRequest request_;
};

class ProtoRunGraphRequest : public RunGraphRequestWrapper {
 public:
  ProtoRunGraphRequest(const RunGraphRequest* request);

  // RunGraphRequestWrapper methods.
  const string& session_handle() const override;
  bool create_worker_session_called() const override;
  const string& graph_handle() const override;
  int64_t step_id() const override;
  const ExecutorOpts& exec_opts() const override;
  size_t num_sends() const override;
  const string& send_key(size_t i) const override;
  Status SendValue(size_t i, Tensor* out_tensor) const override;
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  bool is_partial() const override;
  bool is_last_partial_run() const override;
  bool store_errors_in_response_body() const override;
  int64_t request_id() const override;
  const RunGraphRequest& ToProto() const override;

 private:
  const RunGraphRequest* const request_;  // Not owned.
};
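
// For illustration, a sketch of how a worker-side handler might read the sent
// tensors out of an incoming request. `proto_request` is a hypothetical
// `const RunGraphRequest*` provided by the RPC layer:
//
//   ProtoRunGraphRequest request(proto_request);
//   for (size_t i = 0; i < request.num_sends(); ++i) {
//     Tensor val;
//     TF_RETURN_IF_ERROR(request.SendValue(i, &val));
//     // ... deliver `val` to the rendezvous under request.send_key(i) ...
//   }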

////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `WorkerService.RunGraph` response message.
//
// The `RunGraphResponse` message can contain potentially large tensor
// data as part of its `recv` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for a mutable RunGraphResponse message.
//
// Note that there is no corresponding (immutable)
// RunGraphResponseWrapper class, because the RunGraphResponse object
// is always used as a mutable pointer.
class MutableRunGraphResponseWrapper {
 public:
  virtual ~MutableRunGraphResponseWrapper() {}

  // A list of tensors corresponding to those requested by
  // `RunGraphRequest.recv_key`.
  virtual size_t num_recvs() const = 0;
  virtual const string& recv_key(size_t i) const = 0;
  // NOTE: The following methods may perform a destructive read, for
  // efficiency.
  virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
  virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
  virtual void AddRecv(const string& key, const Tensor& value) = 0;

  // Submessages that store performance statistics about the subgraph
  // execution, if necessary.
  virtual StepStats* mutable_step_stats() = 0;
  virtual CostGraphDef* mutable_cost_graph() = 0;
  virtual size_t num_partition_graphs() const = 0;
  virtual GraphDef* mutable_partition_graph(size_t i) = 0;
  virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;

  // Returned status if requested.
  virtual Status status() const = 0;
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `WorkerInterface` subclasses may call this method. The
  // `InMemoryRunGraphResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
  // from `WorkerInterface::CreateRunGraphResponse()` called on the
  // *same* WorkerInterface object, this error will never trigger.
  virtual RunGraphResponse* get_proto() = 0;
  friend class WorkerInterface;
};
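
// For illustration, a sketch of how a worker might populate a response after
// running a subgraph. `outputs` is a hypothetical list of (recv key, tensor)
// pairs produced by the executor, and `FillResponse` is a hypothetical helper:
//
//   void FillResponse(const std::vector<std::pair<string, Tensor>>& outputs,
//                     MutableRunGraphResponseWrapper* response) {
//     for (const auto& key_and_tensor : outputs) {
//       response->AddRecv(key_and_tensor.first, key_and_tensor.second);
//     }
//     response->set_status(Status::OK());
//   }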

class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE: This method is not implemented. See
  // MutableRunGraphResponseWrapper for an explanation.
  RunGraphResponse* get_proto() override;

 private:
  gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
  StepStats step_stats_;
  CostGraphDef cost_graph_;
  std::vector<GraphDef> partition_graphs_;
  // The error status for this response, exposed via the `status_code()` and
  // `status_error_message()` accessors.
  Status status_;
};

// Proto-based message wrapper for use on the client side of the RunGraph RPC.
class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunGraphResponse* get_proto() override;

 private:
  RunGraphResponse response_;
};

// Proto-based message wrapper for use on the server side of the RunGraph RPC.
class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
 public:
  NonOwnedProtoRunGraphResponse(RunGraphResponse* response);

  // MutableRunGraphResponseWrapper methods.
  size_t num_recvs() const override;
  const string& recv_key(size_t i) const override;
  Status RecvValue(size_t i, TensorProto* out_tensor) override;
  Status RecvValue(size_t i, Tensor* out_tensor) override;
  void AddRecv(const string& key, const Tensor& value) override;
  StepStats* mutable_step_stats() override;
  CostGraphDef* mutable_cost_graph() override;
  size_t num_partition_graphs() const override;
  GraphDef* mutable_partition_graph(size_t i) override;
  void AddPartitionGraph(const GraphDef& partition_graph) override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunGraphResponse* get_proto() override;

 private:
  RunGraphResponse* const response_;  // Not owned.
};

////////////////////////////////////////////////////////////////////////////////
//
// Wrapper classes for the `MasterService.RunStep` response message.
//
// The `RunStepResponse` message can contain potentially large tensor
// data as part of its `tensor` submessages. Here we provide specialized
// wrappers that avoid copying the tensor data wherever possible.
//
// See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
// protocol buffer definition.
//
////////////////////////////////////////////////////////////////////////////////

// Abstract interface for a mutable RunStepResponse message.
//
// Note that there is no corresponding (immutable)
// RunStepResponseWrapper class, because the RunStepResponse object is
// always used as a mutable pointer.
class MutableRunStepResponseWrapper {
 public:
  virtual ~MutableRunStepResponseWrapper();

  // The values of the tensors whose fetching was requested in the
  // RunStep call.
  //
  // NOTE: The order of the returned tensors may or may not match
  // the fetch order specified in RunStepRequest.
  virtual size_t num_tensors() const = 0;
  virtual const string& tensor_name(size_t i) const = 0;
  virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;

  // Stores the i^{th} recv value in `run_graph_response` in this
  // response with the given `name`.
  virtual Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) = 0;

  // Returned metadata if requested in the options.
  virtual const RunMetadata& metadata() const = 0;
  virtual RunMetadata* mutable_metadata() = 0;

  // Returned status if requested.
  virtual Status status() const = 0;
  virtual errors::Code status_code() const = 0;
  virtual const string& status_error_message() const = 0;
  virtual void set_status(const Status& status) = 0;

 protected:
  // Returns a mutable protobuf message that represents the contents of
  // this wrapper, for passing to an RPC subsystem that will populate
  // the message.
  //
  // NOTE: Only `MasterInterface` subclasses may call this method. The
  // `InMemoryRunStepResponse` subclass does not implement this
  // method, and attempts to call it will fail with a fatal
  // error. However, as long as callers always call
  // `MasterInterface::RunStep()` with a wrapper object returned
  // from `MasterInterface::CreateRunStepResponse()` called on the
  // *same* MasterInterface object, this error will never trigger.
  virtual RunStepResponse* get_proto() = 0;
  friend class MasterInterface;
};
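
// For illustration, a simplified sketch of how the master might move fetched
// tensors from a worker's RunGraph response into the client-facing RunStep
// response. Real callers map recv keys back to fetch names; here the recv key
// is reused directly, and `CopyFetches` is a hypothetical helper:
//
//   Status CopyFetches(MutableRunGraphResponseWrapper* run_graph_response,
//                      MutableRunStepResponseWrapper* run_step_response) {
//     for (size_t i = 0; i < run_graph_response->num_recvs(); ++i) {
//       TF_RETURN_IF_ERROR(run_step_response->AddTensorFromRunGraphResponse(
//           run_graph_response->recv_key(i), run_graph_response, i));
//     }
//     return Status::OK();
//   }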

class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  // NOTE: This method is not implemented. See
  // MutableRunGraphResponseWrapper for an explanation.
  RunStepResponse* get_proto() override;

 private:
  gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
  RunMetadata metadata_;
  // The error status for this response, exposed via the `status_code()` and
  // `status_error_message()` accessors.
  Status status_;
};

// Proto-based message wrapper for use on the client side of the RunStep RPC.
class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunStepResponse* get_proto() override;

 private:
  RunStepResponse response_;
};

// Proto-based message wrapper for use on the server side of the RunStep RPC.
class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
 public:
  NonOwnedProtoRunStepResponse(RunStepResponse* response);

  // MutableRunStepResponseWrapper methods.
  size_t num_tensors() const override;
  const string& tensor_name(size_t i) const override;
  Status TensorValue(size_t i, Tensor* out_tensor) const override;
  Status AddTensorFromRunGraphResponse(
      const string& name, MutableRunGraphResponseWrapper* run_graph_response,
      size_t i) override;
  const RunMetadata& metadata() const override;
  RunMetadata* mutable_metadata() override;
  Status status() const override;
  errors::Code status_code() const override;
  const string& status_error_message() const override;
  void set_status(const Status& status) override;

 protected:
  RunStepResponse* get_proto() override;

 private:
  RunStepResponse* response_;  // Not owned.
};
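
// For illustration, a sketch of a server-side RunStep handler wrapping the
// RPC-provided response proto so downstream code can use the wrapper
// interface. `response_proto` is a hypothetical `RunStepResponse*` owned by
// the RPC layer:
//
//   NonOwnedProtoRunStepResponse wrapped_response(response_proto);
//   wrapped_response.set_status(Status::OK());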

// Parses `tensor_proto` into `*out_tensor`. Returns false if the proto could
// not be parsed into a valid tensor.
bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
                              Tensor* out_tensor);
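
// For illustration, a typical call pattern; the `proto` value is hypothetical:
//
//   Tensor parsed;
//   if (!ParseTensorProtoToTensor(proto, &parsed)) {
//     return errors::InvalidArgument("Could not parse tensor from proto.");
//   }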

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_