1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/distributed_runtime/message_wrappers.h"
17
18#include "tensorflow/core/framework/cost_graph.pb.h"
19#include "tensorflow/core/framework/step_stats.pb.h"
20#include "tensorflow/core/framework/tensor.pb.h"
21#include "tensorflow/core/protobuf/config.pb.h"
22#include "tensorflow/core/protobuf/named_tensor.pb.h"
23
24namespace tensorflow {
25
26bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
27 Tensor* out_tensor) {
28 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
29 Tensor parsed(tensor_proto.dtype());
30 if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
31 *out_tensor = parsed;
32 return true;
33 }
34 }
35 return false;
36}
37
// ---------------------------------------------------------------------------
// InMemoryRunStepRequest: keeps feeds as in-process Tensors and plain string
// vectors, and only materializes a RunStepRequest proto lazily in ToProto().
// ---------------------------------------------------------------------------

const string& InMemoryRunStepRequest::session_handle() const {
  return session_handle_;
}

void InMemoryRunStepRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

const string& InMemoryRunStepRequest::partial_run_handle() const {
  return partial_run_handle_;
}

void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
  partial_run_handle_ = handle;
}

size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
const string& InMemoryRunStepRequest::feed_name(size_t i) const {
  return feeds_[i].first;
}

// Returns the i-th feed tensor directly; cannot fail for the in-memory
// representation since no proto parsing is involved.
Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = feeds_[i].second;
  return OkStatus();
}

// Serializes the i-th feed tensor into `*out_tensor`.
Status InMemoryRunStepRequest::FeedValue(size_t i,
                                         TensorProto* out_tensor) const {
  feeds_[i].second.AsProtoTensorContent(out_tensor);
  return OkStatus();
}

void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
  feeds_.emplace_back(name, value);
}

size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
  return fetches_[i];
}
void InMemoryRunStepRequest::add_fetch(const string& name) {
  fetches_.push_back(name);
}

size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
const string& InMemoryRunStepRequest::target_name(size_t i) const {
  return targets_[i];
}
void InMemoryRunStepRequest::add_target(const string& name) {
  targets_.push_back(name);
}

const RunOptions& InMemoryRunStepRequest::options() const { return options_; }

RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }

bool InMemoryRunStepRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

int64_t InMemoryRunStepRequest::request_id() const {
  return 0;  // no need to track request id for local version.
}

void InMemoryRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}

string InMemoryRunStepRequest::DebugString() const {
  return ToProto().DebugString();
}

// Lazily builds and caches the proto form of this request.
// NOTE(review): the cache is never invalidated, so mutating this request
// after the first ToProto() call leaves the cached proto stale — callers
// appear expected to finish populating the request before converting.
const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunStepRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_partial_run_handle(partial_run_handle());
    for (size_t i = 0; i < num_feeds(); ++i) {
      auto feed = proto_version_->add_feed();
      feed->set_name(feed_name(i));
      feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
    }
    for (size_t i = 0; i < num_fetches(); ++i) {
      proto_version_->add_fetch(fetch_name(i));
    }
    for (size_t i = 0; i < num_targets(); ++i) {
      proto_version_->add_target(target_name(i));
    }
    *proto_version_->mutable_options() = options();
  }
  return *proto_version_;
}
131
// ---------------------------------------------------------------------------
// MutableProtoRunStepRequest: thin mutable wrapper whose entire state lives
// in an owned RunStepRequest proto, so ToProto() is free.
// ---------------------------------------------------------------------------

const string& MutableProtoRunStepRequest::session_handle() const {
  return request_.session_handle();
}
void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

const string& MutableProtoRunStepRequest::partial_run_handle() const {
  return request_.partial_run_handle();
}
void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
  request_.set_partial_run_handle(handle);
}

size_t MutableProtoRunStepRequest::num_feeds() const {
  return request_.feed_size();
}
const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
  return request_.feed(i).name();
}

// Parses the stored TensorProto for feed `i`; fails on a malformed proto.
Status MutableProtoRunStepRequest::FeedValue(size_t i,
                                             Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return OkStatus();
  }
}

// Copies the stored TensorProto for feed `i`; cannot fail.
Status MutableProtoRunStepRequest::FeedValue(size_t i,
                                             TensorProto* out_tensor) const {
  *out_tensor = request_.feed(i).tensor();
  return OkStatus();
}

// Serializes `value` directly into a new feed entry of the proto.
void MutableProtoRunStepRequest::add_feed(const string& name,
                                          const Tensor& value) {
  NamedTensorProto* feed = request_.add_feed();
  feed->set_name(name);
  TensorProto* value_proto = feed->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

size_t MutableProtoRunStepRequest::num_fetches() const {
  return request_.fetch_size();
}

const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
  return request_.fetch(i);
}
void MutableProtoRunStepRequest::add_fetch(const string& name) {
  request_.add_fetch(name);
}

size_t MutableProtoRunStepRequest::num_targets() const {
  return request_.target_size();
}

const string& MutableProtoRunStepRequest::target_name(size_t i) const {
  return request_.target(i);
}

void MutableProtoRunStepRequest::add_target(const string& name) {
  request_.add_target(name);
}

const RunOptions& MutableProtoRunStepRequest::options() const {
  return request_.options();
}

RunOptions* MutableProtoRunStepRequest::mutable_options() {
  return request_.mutable_options();
}

bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

int64_t MutableProtoRunStepRequest::request_id() const {
  return request_.request_id();
}

string MutableProtoRunStepRequest::DebugString() const {
  return request_.DebugString();
}

// Already proto-backed, so the conversion is a no-op.
const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
  return request_;
}
226
// ---------------------------------------------------------------------------
// ProtoRunStepRequest: immutable view over a borrowed RunStepRequest proto.
// The caller must keep `*request` alive for the lifetime of this wrapper.
// ---------------------------------------------------------------------------

ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
    : request_(request) {}

const string& ProtoRunStepRequest::session_handle() const {
  return request_->session_handle();
}

const string& ProtoRunStepRequest::partial_run_handle() const {
  return request_->partial_run_handle();
}

size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }

const string& ProtoRunStepRequest::feed_name(size_t i) const {
  return request_->feed(i).name();
}

// Parses the stored TensorProto for feed `i`; fails on a malformed proto.
Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return OkStatus();
  }
}

Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
  *out_tensor = request_->feed(i).tensor();
  return OkStatus();
}

size_t ProtoRunStepRequest::num_fetches() const {
  return request_->fetch_size();
}

const string& ProtoRunStepRequest::fetch_name(size_t i) const {
  return request_->fetch(i);
}

size_t ProtoRunStepRequest::num_targets() const {
  return request_->target_size();
}

const string& ProtoRunStepRequest::target_name(size_t i) const {
  return request_->target(i);
}

const RunOptions& ProtoRunStepRequest::options() const {
  return request_->options();
}

bool ProtoRunStepRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64_t ProtoRunStepRequest::request_id() const {
  return request_->request_id();
}

string ProtoRunStepRequest::DebugString() const {
  return request_->DebugString();
}

// Returns the borrowed proto itself; no conversion is needed.
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
290
// ---------------------------------------------------------------------------
// InMemoryRunGraphRequest: keeps sends as in-process Tensors and only
// materializes a RunGraphRequest proto lazily in ToProto().
// ---------------------------------------------------------------------------

const string& InMemoryRunGraphRequest::session_handle() const {
  return session_handle_;
}

bool InMemoryRunGraphRequest::create_worker_session_called() const {
  return create_worker_session_called_;
}

void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
  create_worker_session_called_ = called;
}

const string& InMemoryRunGraphRequest::graph_handle() const {
  return graph_handle_;
}

void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
  graph_handle_ = handle;
}

int64_t InMemoryRunGraphRequest::step_id() const { return step_id_; }

void InMemoryRunGraphRequest::set_step_id(int64_t step_id) {
  step_id_ = step_id;
}

const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
  return exec_opts_;
}

ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
  return &exec_opts_;
}

size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }

const string& InMemoryRunGraphRequest::send_key(size_t i) const {
  return sends_[i].first;
}

// Returns the i-th send tensor directly; cannot fail for the in-memory
// representation.
Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = sends_[i].second;
  return OkStatus();
}

// Copies feed `i` of `run_step_request` into this request's send list under
// `send_key`. Fails if the feed cannot be materialized as a Tensor.
Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
  sends_.emplace_back(send_key, std::move(tensor));
  return OkStatus();
}

// Same as above, but sourced from a RunCallableRequest feed, which is stored
// as a TensorProto and must be parsed first.
Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  }
  sends_.emplace_back(send_key, std::move(tensor));
  return OkStatus();
}

size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
  return recvs_[i];
}

void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
  recvs_.push_back(recv_key);
}

bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }

void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
  is_partial_ = is_partial;
}

bool InMemoryRunGraphRequest::is_last_partial_run() const {
  return is_last_partial_run_;
}

void InMemoryRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  is_last_partial_run_ = is_last_partial_run;
}

bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

void InMemoryRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}

int64_t InMemoryRunGraphRequest::request_id() const { return request_id_; }

void InMemoryRunGraphRequest::set_request_id(int64_t request_id) {
  request_id_ = request_id;
}

// Lazily builds and caches the proto form of this request. The structural
// fields (handles, sends, recvs, partial flags, exec opts) are converted once
// and never invalidated, so they must not change after the first call.
// store_errors_in_response_body and request_id are deliberately re-copied on
// every call, outside the cache-building branch, so later updates to those
// two fields are still reflected in the returned proto.
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunGraphRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_create_worker_session_called(
        create_worker_session_called());
    proto_version_->set_graph_handle(graph_handle());
    proto_version_->set_step_id(step_id());
    *proto_version_->mutable_exec_opts() = exec_opts();
    for (size_t i = 0; i < num_sends(); ++i) {
      auto send = proto_version_->add_send();
      send->set_name(send_key(i));
      sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
    }
    for (size_t i = 0; i < num_recvs(); ++i) {
      proto_version_->add_recv_key(recv_key(i));
    }
    proto_version_->set_is_partial(is_partial());
    proto_version_->set_is_last_partial_run(is_last_partial_run());
  }
  proto_version_->set_store_errors_in_response_body(
      store_errors_in_response_body_);
  proto_version_->set_request_id(request_id_);
  return *proto_version_;
}
425
// ---------------------------------------------------------------------------
// MutableProtoRunGraphRequest: mutable wrapper whose entire state lives in
// an owned RunGraphRequest proto, so ToProto() is free.
// ---------------------------------------------------------------------------

const string& MutableProtoRunGraphRequest::session_handle() const {
  return request_.session_handle();
}

void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

bool MutableProtoRunGraphRequest::create_worker_session_called() const {
  return request_.create_worker_session_called();
}

void MutableProtoRunGraphRequest::set_create_worker_session_called(
    bool called) {
  request_.set_create_worker_session_called(called);
}

const string& MutableProtoRunGraphRequest::graph_handle() const {
  return request_.graph_handle();
}

void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
  request_.set_graph_handle(handle);
}

int64_t MutableProtoRunGraphRequest::step_id() const {
  return request_.step_id();
}

void MutableProtoRunGraphRequest::set_step_id(int64_t step_id) {
  request_.set_step_id(step_id);
}

const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
  return request_.exec_opts();
}

ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
  return request_.mutable_exec_opts();
}

size_t MutableProtoRunGraphRequest::num_sends() const {
  return request_.send_size();
}

const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
  return request_.send(i).name();
}

// Parses the stored TensorProto for send `i`; fails on a malformed proto.
// NOTE(review): the error text says "feed value" even though this is a send;
// the message matches the FeedValue accessors since sends originate from
// step feeds — confirm before changing the wording.
Status MutableProtoRunGraphRequest::SendValue(size_t i,
                                              Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return OkStatus();
  }
}

// Serializes feed `i` of `run_step_request` into a new send entry.
Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
  return OkStatus();
}

// Copies the already-serialized feed `i` of `run_callable_request` into a
// new send entry; no parsing is needed, so this cannot fail.
Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  *send->mutable_tensor() = run_callable_request.feed(i);
  return OkStatus();
}

size_t MutableProtoRunGraphRequest::num_recvs() const {
  return request_.recv_key_size();
}

const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
  return request_.recv_key(i);
}

void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
  request_.add_recv_key(recv_key);
}

bool MutableProtoRunGraphRequest::is_partial() const {
  return request_.is_partial();
}

void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
  request_.set_is_partial(is_partial);
}

bool MutableProtoRunGraphRequest::is_last_partial_run() const {
  return request_.is_last_partial_run();
}

void MutableProtoRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  request_.set_is_last_partial_run(is_last_partial_run);
}

bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

int64_t MutableProtoRunGraphRequest::request_id() const {
  return request_.request_id();
}

void MutableProtoRunGraphRequest::set_request_id(int64_t request_id) {
  request_.set_request_id(request_id);
}

// Already proto-backed, so the conversion is a no-op.
const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
  return request_;
}
551
// ---------------------------------------------------------------------------
// ProtoRunGraphRequest: immutable view over a borrowed RunGraphRequest proto.
// The caller must keep `*request` alive for the lifetime of this wrapper.
// ---------------------------------------------------------------------------

ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
    : request_(request) {}

const string& ProtoRunGraphRequest::session_handle() const {
  return request_->session_handle();
}

bool ProtoRunGraphRequest::create_worker_session_called() const {
  return request_->create_worker_session_called();
}

const string& ProtoRunGraphRequest::graph_handle() const {
  return request_->graph_handle();
}

int64_t ProtoRunGraphRequest::step_id() const { return request_->step_id(); }

const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
  return request_->exec_opts();
}

size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }

const string& ProtoRunGraphRequest::send_key(size_t i) const {
  return request_->send(i).name();
}

// Parses the stored TensorProto for send `i`; fails on a malformed proto.
Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
  } else {
    return OkStatus();
  }
}

size_t ProtoRunGraphRequest::num_recvs() const {
  return request_->recv_key_size();
}

const string& ProtoRunGraphRequest::recv_key(size_t i) const {
  return request_->recv_key(i);
}

bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }

bool ProtoRunGraphRequest::is_last_partial_run() const {
  return request_->is_last_partial_run();
}

bool ProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64_t ProtoRunGraphRequest::request_id() const {
  return request_->request_id();
}

// Returns the borrowed proto itself; no conversion is needed.
const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
  return *request_;
}
612
// ---------------------------------------------------------------------------
// InMemoryRunGraphResponse: stores recvs, stats, and status as in-process
// objects; it has no proto form (see get_proto below).
// ---------------------------------------------------------------------------

size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
  return recvs_[i].first;
}

// Serializes the i-th recv tensor into `*out_tensor`; cannot fail.
Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
  recvs_[i].second.AsProtoTensorContent(out_tensor);
  return OkStatus();
}

// Returns the i-th recv tensor directly; cannot fail.
Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  *out_tensor = recvs_[i].second;
  return OkStatus();
}

void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
  recvs_.emplace_back(key, value);
}

StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
  return &step_stats_;
}

CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
  return &cost_graph_;
}

Status InMemoryRunGraphResponse::status() const { return status_; }

errors::Code InMemoryRunGraphResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunGraphResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunGraphResponse::set_status(const Status& status) {
  status_ = status;
}

// The in-memory response has no backing proto; requesting one is a
// programming error and aborts the process.
RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
  return nullptr;
}

size_t InMemoryRunGraphResponse::num_partition_graphs() const {
  return partition_graphs_.size();
}

GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
  return &partition_graphs_[i];
}

void InMemoryRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  partition_graphs_.push_back(partition_graph);
}
672
// ---------------------------------------------------------------------------
// OwnedProtoRunGraphResponse: response wrapper backed by an owned
// RunGraphResponse proto.
// ---------------------------------------------------------------------------

size_t OwnedProtoRunGraphResponse::num_recvs() const {
  return response_.recv_size();
}

const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_.recv(i).name();
}

// Moves the i-th recv proto into `*out_tensor` via Swap. Destructive: the
// stored recv tensor is replaced by `out_tensor`'s previous contents.
Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
                                             TensorProto* out_tensor) {
  out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
  return OkStatus();
}

// Parses the i-th recv proto into a Tensor; fails on a malformed proto.
Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
  } else {
    return OkStatus();
  }
}

void OwnedProtoRunGraphResponse::AddRecv(const string& key,
                                         const Tensor& value) {
  NamedTensorProto* recv = response_.add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_.mutable_step_stats();
}

CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_.mutable_cost_graph();
}

// Reassembles a Status from the code/message fields stored in the proto.
Status OwnedProtoRunGraphResponse::status() const {
  return Status(response_.status_code(), response_.status_error_message());
}

errors::Code OwnedProtoRunGraphResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunGraphResponse::status_error_message() const {
  return response_.status_error_message();
}

// Stores the status as separate code/message fields in the proto.
void OwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }

size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_.partition_graph_size();
}

GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_.mutable_partition_graph(i);
}

void OwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_.mutable_partition_graph()->Add();
  *graph_def = partition_graph;
}
743
// ---------------------------------------------------------------------------
// NonOwnedProtoRunGraphResponse: response wrapper over a borrowed, mutable
// RunGraphResponse proto. The caller must keep `*response` alive.
// ---------------------------------------------------------------------------

NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
    RunGraphResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
  return response_->recv_size();
}

const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_->recv(i).name();
}

// Moves the i-th recv proto into `*out_tensor` via Swap. Destructive: the
// stored recv tensor is replaced by `out_tensor`'s previous contents.
Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
                                                TensorProto* out_tensor) {
  out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
  return OkStatus();
}

// Parses the i-th recv proto into a Tensor; fails on a malformed proto.
Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
  } else {
    return OkStatus();
  }
}

void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
                                            const Tensor& value) {
  NamedTensorProto* recv = response_->add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_->mutable_step_stats();
}

CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_->mutable_cost_graph();
}

// Reassembles a Status from the code/message fields stored in the proto.
Status NonOwnedProtoRunGraphResponse::status() const {
  return Status(response_->status_code(), response_->status_error_message());
}

errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
  return response_->status_error_message();
}

void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
  return response_;
}

size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_->partition_graph_size();
}

GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_->mutable_partition_graph(i);
}

void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_->add_partition_graph();
  *graph_def = partition_graph;
}
820
// Out-of-line destructor definition anchors the vtable of the abstract
// response-wrapper interface in this translation unit.
MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
822
// ---------------------------------------------------------------------------
// InMemoryRunStepResponse: stores fetched tensors, metadata, and status as
// in-process objects; it has no proto form (see get_proto below).
// ---------------------------------------------------------------------------

size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }

const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
  return tensors_[i].first;
}

// Returns the i-th fetched tensor directly; cannot fail for the in-memory
// representation.
Status InMemoryRunStepResponse::TensorValue(size_t i,
                                            Tensor* out_tensor) const {
  *out_tensor = tensors_[i].second;
  return OkStatus();
}

const RunMetadata& InMemoryRunStepResponse::metadata() const {
  return metadata_;
}
838
839Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
840 const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
841 Tensor tensor;
842 TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
843 tensors_.emplace_back(name, tensor);
844 return OkStatus();
845}
846
RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }

Status InMemoryRunStepResponse::status() const { return status_; }

errors::Code InMemoryRunStepResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunStepResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunStepResponse::set_status(const Status& status) {
  status_ = status;
}

// The in-memory response has no backing proto; requesting one is a
// programming error and aborts the process.
RunStepResponse* InMemoryRunStepResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
  return nullptr;
}
867
// ---------------------------------------------------------------------------
// OwnedProtoRunStepResponse: response wrapper backed by an owned
// RunStepResponse proto.
// ---------------------------------------------------------------------------

size_t OwnedProtoRunStepResponse::num_tensors() const {
  return response_.tensor_size();
}

const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_.tensor(i).name();
}

// Parses the i-th fetched tensor proto; fails on a malformed proto.
Status OwnedProtoRunStepResponse::TensorValue(size_t i,
                                              Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
  } else {
    return OkStatus();
  }
}

const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
  return response_.metadata();
}

// Appends recv value `i` of `run_graph_response` as a named tensor. On
// error the entry remains appended with its name set but no tensor payload.
Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_.add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
  return response_.mutable_metadata();
}

// Reassembles a Status from the code/message fields stored in the proto.
Status OwnedProtoRunStepResponse::status() const {
  return Status(response_.status_code(), response_.status_error_message());
}

errors::Code OwnedProtoRunStepResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunStepResponse::status_error_message() const {
  return response_.status_error_message();
}

void OwnedProtoRunStepResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
919
// ---------------------------------------------------------------------------
// NonOwnedProtoRunStepResponse: response wrapper over a borrowed, mutable
// RunStepResponse proto. The caller must keep `*response` alive.
// ---------------------------------------------------------------------------

NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
    RunStepResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunStepResponse::num_tensors() const {
  return response_->tensor_size();
}

const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_->tensor(i).name();
}

// Parses the i-th fetched tensor proto; fails on a malformed proto.
Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
                                                 Tensor* out_tensor) const {
  if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
  } else {
    return OkStatus();
  }
}

const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
  return response_->metadata();
}

// Appends recv value `i` of `run_graph_response` as a named tensor. On
// error the entry remains appended with its name set but no tensor payload.
Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_->add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
  return response_->mutable_metadata();
}

// Reassembles a Status from the code/message fields stored in the proto.
Status NonOwnedProtoRunStepResponse::status() const {
  return Status(response_->status_code(), response_->status_error_message());
}

errors::Code NonOwnedProtoRunStepResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunStepResponse::status_error_message() const {
  return response_->status_error_message();
}

void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
975
976} // namespace tensorflow
977