1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/message_wrappers.h" |
17 | |
18 | #include "tensorflow/core/framework/cost_graph.pb.h" |
19 | #include "tensorflow/core/framework/step_stats.pb.h" |
20 | #include "tensorflow/core/framework/tensor.pb.h" |
21 | #include "tensorflow/core/protobuf/config.pb.h" |
22 | #include "tensorflow/core/protobuf/named_tensor.pb.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, |
27 | Tensor* out_tensor) { |
28 | if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { |
29 | Tensor parsed(tensor_proto.dtype()); |
30 | if (parsed.FromProto(cpu_allocator(), tensor_proto)) { |
31 | *out_tensor = parsed; |
32 | return true; |
33 | } |
34 | } |
35 | return false; |
36 | } |
37 | |
// ---------------------------------------------------------------------------
// InMemoryRunStepRequest: keeps feeds as live in-process Tensors and only
// serializes to proto form when ToProto() is explicitly requested.
// ---------------------------------------------------------------------------

const string& InMemoryRunStepRequest::session_handle() const {
  return session_handle_;
}

void InMemoryRunStepRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

const string& InMemoryRunStepRequest::partial_run_handle() const {
  return partial_run_handle_;
}

void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
  partial_run_handle_ = handle;
}

size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
const string& InMemoryRunStepRequest::feed_name(size_t i) const {
  return feeds_[i].first;
}

// Copies the stored Tensor for feed i into `*out_tensor`.
// `i` must be < num_feeds(); the index is not range-checked here.
Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = feeds_[i].second;
  return OkStatus();
}

// Serializes the stored Tensor for feed i into `*out_tensor` using
// AsProtoTensorContent.
Status InMemoryRunStepRequest::FeedValue(size_t i,
                                         TensorProto* out_tensor) const {
  feeds_[i].second.AsProtoTensorContent(out_tensor);
  return OkStatus();
}

void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
  feeds_.emplace_back(name, value);
}

size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
  return fetches_[i];
}
void InMemoryRunStepRequest::add_fetch(const string& name) {
  fetches_.push_back(name);
}

size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
const string& InMemoryRunStepRequest::target_name(size_t i) const {
  return targets_[i];
}
void InMemoryRunStepRequest::add_target(const string& name) {
  targets_.push_back(name);
}

const RunOptions& InMemoryRunStepRequest::options() const { return options_; }

RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }

bool InMemoryRunStepRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

int64_t InMemoryRunStepRequest::request_id() const {
  return 0;  // no need to track request id for local version.
}

void InMemoryRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}
106 | |
// Debug text is produced from the proto form; this triggers (and caches) the
// lazy conversion in ToProto().
string InMemoryRunStepRequest::DebugString() const {
  return ToProto().DebugString();
}

// Lazily builds and caches the RunStepRequest proto equivalent of this
// wrapper. NOTE: the cache is filled only once and never invalidated, so any
// mutation made after the first ToProto() call is not reflected in the
// returned proto.
const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunStepRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_partial_run_handle(partial_run_handle());
    for (size_t i = 0; i < num_feeds(); ++i) {
      auto feed = proto_version_->add_feed();
      feed->set_name(feed_name(i));
      feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
    }
    for (size_t i = 0; i < num_fetches(); ++i) {
      proto_version_->add_fetch(fetch_name(i));
    }
    for (size_t i = 0; i < num_targets(); ++i) {
      proto_version_->add_target(target_name(i));
    }
    *proto_version_->mutable_options() = options();
  }
  return *proto_version_;
}
131 | |
// ---------------------------------------------------------------------------
// MutableProtoRunStepRequest: mutable wrapper whose backing store is an owned
// RunStepRequest proto (request_); every accessor delegates directly to it.
// ---------------------------------------------------------------------------

const string& MutableProtoRunStepRequest::session_handle() const {
  return request_.session_handle();
}
void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

const string& MutableProtoRunStepRequest::partial_run_handle() const {
  return request_.partial_run_handle();
}
void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
  request_.set_partial_run_handle(handle);
}

size_t MutableProtoRunStepRequest::num_feeds() const {
  return request_.feed_size();
}
const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
  return request_.feed(i).name();
}
152 | Status MutableProtoRunStepRequest::FeedValue(size_t i, |
153 | Tensor* out_tensor) const { |
154 | if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) { |
155 | return errors::InvalidArgument("Invalid TensorProto for feed value " , i); |
156 | } else { |
157 | return OkStatus(); |
158 | } |
159 | } |
160 | |
// Returns feed i's value in serialized form: a straight copy of the proto
// already held in request_.
Status MutableProtoRunStepRequest::FeedValue(size_t i,
                                             TensorProto* out_tensor) const {
  *out_tensor = request_.feed(i).tensor();
  return OkStatus();
}

// Appends a feed; `value` is serialized into the proto immediately.
void MutableProtoRunStepRequest::add_feed(const string& name,
                                          const Tensor& value) {
  NamedTensorProto* feed = request_.add_feed();
  feed->set_name(name);
  TensorProto* value_proto = feed->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

size_t MutableProtoRunStepRequest::num_fetches() const {
  return request_.fetch_size();
}

const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
  return request_.fetch(i);
}
void MutableProtoRunStepRequest::add_fetch(const string& name) {
  request_.add_fetch(name);
}

size_t MutableProtoRunStepRequest::num_targets() const {
  return request_.target_size();
}

const string& MutableProtoRunStepRequest::target_name(size_t i) const {
  return request_.target(i);
}

void MutableProtoRunStepRequest::add_target(const string& name) {
  request_.add_target(name);
}

const RunOptions& MutableProtoRunStepRequest::options() const {
  return request_.options();
}

RunOptions* MutableProtoRunStepRequest::mutable_options() {
  return request_.mutable_options();
}

bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

int64_t MutableProtoRunStepRequest::request_id() const {
  return request_.request_id();
}

string MutableProtoRunStepRequest::DebugString() const {
  return request_.DebugString();
}

// No conversion needed: the proto is the backing store.
const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
  return request_;
}
226 | |
// ---------------------------------------------------------------------------
// ProtoRunStepRequest: read-only view over a borrowed RunStepRequest proto.
// The wrapper stores a raw, non-owning pointer, so `request` must outlive
// this object.
// ---------------------------------------------------------------------------

ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
    : request_(request) {}

const string& ProtoRunStepRequest::session_handle() const {
  return request_->session_handle();
}

const string& ProtoRunStepRequest::partial_run_handle() const {
  return request_->partial_run_handle();
}

size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }

const string& ProtoRunStepRequest::feed_name(size_t i) const {
  return request_->feed(i).name();
}
243 | |
244 | Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { |
245 | if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) { |
246 | return errors::InvalidArgument("Invalid TensorProto for feed value " , i); |
247 | } else { |
248 | return OkStatus(); |
249 | } |
250 | } |
251 | |
// Copies feed i's stored TensorProto verbatim.
Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
  *out_tensor = request_->feed(i).tensor();
  return OkStatus();
}

size_t ProtoRunStepRequest::num_fetches() const {
  return request_->fetch_size();
}

const string& ProtoRunStepRequest::fetch_name(size_t i) const {
  return request_->fetch(i);
}

size_t ProtoRunStepRequest::num_targets() const {
  return request_->target_size();
}

const string& ProtoRunStepRequest::target_name(size_t i) const {
  return request_->target(i);
}

const RunOptions& ProtoRunStepRequest::options() const {
  return request_->options();
}

bool ProtoRunStepRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64_t ProtoRunStepRequest::request_id() const {
  return request_->request_id();
}

string ProtoRunStepRequest::DebugString() const {
  return request_->DebugString();
}

// The borrowed proto is returned directly; no copy is made.
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
290 | |
// ---------------------------------------------------------------------------
// InMemoryRunGraphRequest: keeps send tensors as in-process Tensors; proto
// conversion is deferred to ToProto().
// ---------------------------------------------------------------------------

const string& InMemoryRunGraphRequest::session_handle() const {
  return session_handle_;
}

bool InMemoryRunGraphRequest::create_worker_session_called() const {
  return create_worker_session_called_;
}

void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
  session_handle_ = handle;
}

void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
  create_worker_session_called_ = called;
}

const string& InMemoryRunGraphRequest::graph_handle() const {
  return graph_handle_;
}

void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
  graph_handle_ = handle;
}

int64_t InMemoryRunGraphRequest::step_id() const { return step_id_; }

void InMemoryRunGraphRequest::set_step_id(int64_t step_id) {
  step_id_ = step_id;
}

const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
  return exec_opts_;
}

ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
  return &exec_opts_;
}

size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }

const string& InMemoryRunGraphRequest::send_key(size_t i) const {
  return sends_[i].first;
}

// Copies the stored Tensor for send i into `*out_tensor`.
// `i` must be < num_sends(); the index is not range-checked here.
Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
  *out_tensor = sends_[i].second;
  return OkStatus();
}
339 | |
// Copies feed i of `run_step_request` into this request's send list under
// `send_key`. Fails if the feed value cannot be materialized as a Tensor.
Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
  sends_.emplace_back(send_key, std::move(tensor));
  return OkStatus();
}

// Same as above, but the source is a RunCallableRequest whose feeds are
// TensorProtos and must be parsed first.
Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  Tensor tensor;
  if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
    return errors::InvalidArgument("Invalid TensorProto for feed value " , i);
  }
  sends_.emplace_back(send_key, std::move(tensor));
  return OkStatus();
}

size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
  return recvs_[i];
}

void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
  recvs_.push_back(recv_key);
}

bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }

void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
  is_partial_ = is_partial;
}

bool InMemoryRunGraphRequest::is_last_partial_run() const {
  return is_last_partial_run_;
}

void InMemoryRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  is_last_partial_run_ = is_last_partial_run;
}

bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

void InMemoryRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
}

int64_t InMemoryRunGraphRequest::request_id() const { return request_id_; }

void InMemoryRunGraphRequest::set_request_id(int64_t request_id) {
  request_id_ = request_id;
}
399 | |
// Lazily builds and caches the RunGraphRequest proto form of this wrapper.
// Most fields are captured only on the first call (the cache is never
// invalidated for them), but store_errors_in_response_body and request_id
// are re-copied on every call — they sit outside the cache-fill branch, so
// changes to those two fields after the first call ARE reflected.
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
  if (!proto_version_) {
    proto_version_.reset(new RunGraphRequest);
    proto_version_->set_session_handle(session_handle());
    proto_version_->set_create_worker_session_called(
        create_worker_session_called());
    proto_version_->set_graph_handle(graph_handle());
    proto_version_->set_step_id(step_id());
    *proto_version_->mutable_exec_opts() = exec_opts();
    for (size_t i = 0; i < num_sends(); ++i) {
      auto send = proto_version_->add_send();
      send->set_name(send_key(i));
      sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
    }
    for (size_t i = 0; i < num_recvs(); ++i) {
      proto_version_->add_recv_key(recv_key(i));
    }
    proto_version_->set_is_partial(is_partial());
    proto_version_->set_is_last_partial_run(is_last_partial_run());
  }
  proto_version_->set_store_errors_in_response_body(
      store_errors_in_response_body_);
  proto_version_->set_request_id(request_id_);
  return *proto_version_;
}
425 | |
// ---------------------------------------------------------------------------
// MutableProtoRunGraphRequest: mutable wrapper backed directly by an owned
// RunGraphRequest proto (request_).
// ---------------------------------------------------------------------------

const string& MutableProtoRunGraphRequest::session_handle() const {
  return request_.session_handle();
}

void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
  request_.set_session_handle(handle);
}

bool MutableProtoRunGraphRequest::create_worker_session_called() const {
  return request_.create_worker_session_called();
}

void MutableProtoRunGraphRequest::set_create_worker_session_called(
    bool called) {
  request_.set_create_worker_session_called(called);
}

const string& MutableProtoRunGraphRequest::graph_handle() const {
  return request_.graph_handle();
}

void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
  request_.set_graph_handle(handle);
}

int64_t MutableProtoRunGraphRequest::step_id() const {
  return request_.step_id();
}

void MutableProtoRunGraphRequest::set_step_id(int64_t step_id) {
  request_.set_step_id(step_id);
}

const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
  return request_.exec_opts();
}

ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
  return request_.mutable_exec_opts();
}

size_t MutableProtoRunGraphRequest::num_sends() const {
  return request_.send_size();
}

const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
  return request_.send(i).name();
}
475 | Status MutableProtoRunGraphRequest::SendValue(size_t i, |
476 | Tensor* out_tensor) const { |
477 | if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) { |
478 | return errors::InvalidArgument("Invalid TensorProto for feed value " , i); |
479 | } else { |
480 | return OkStatus(); |
481 | } |
482 | } |
483 | |
// Appends feed i of `run_step_request` to this request's sends, serializing
// straight into the owned proto.
Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
    const RunStepRequestWrapper& run_step_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
  return OkStatus();
}

// Same as above for a RunCallableRequest source; the feed is already a
// TensorProto, so it is copied verbatim.
Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
    const RunCallableRequest& run_callable_request, size_t i,
    const string& send_key) {
  NamedTensorProto* send = request_.add_send();
  send->set_name(send_key);
  *send->mutable_tensor() = run_callable_request.feed(i);
  return OkStatus();
}

size_t MutableProtoRunGraphRequest::num_recvs() const {
  return request_.recv_key_size();
}

const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
  return request_.recv_key(i);
}

void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
  request_.add_recv_key(recv_key);
}

bool MutableProtoRunGraphRequest::is_partial() const {
  return request_.is_partial();
}

void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
  request_.set_is_partial(is_partial);
}

bool MutableProtoRunGraphRequest::is_last_partial_run() const {
  return request_.is_last_partial_run();
}

void MutableProtoRunGraphRequest::set_is_last_partial_run(
    bool is_last_partial_run) {
  request_.set_is_last_partial_run(is_last_partial_run);
}

bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_.store_errors_in_response_body();
}

void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
    bool store_errors) {
  request_.set_store_errors_in_response_body(store_errors);
}

int64_t MutableProtoRunGraphRequest::request_id() const {
  return request_.request_id();
}

void MutableProtoRunGraphRequest::set_request_id(int64_t request_id) {
  request_.set_request_id(request_id);
}

// No conversion needed: the proto is the backing store.
const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
  return request_;
}
551 | |
// ---------------------------------------------------------------------------
// ProtoRunGraphRequest: read-only view over a borrowed RunGraphRequest.
// Stores a raw, non-owning pointer; `request` must outlive this wrapper.
// ---------------------------------------------------------------------------

ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
    : request_(request) {}

const string& ProtoRunGraphRequest::session_handle() const {
  return request_->session_handle();
}

bool ProtoRunGraphRequest::create_worker_session_called() const {
  return request_->create_worker_session_called();
}

const string& ProtoRunGraphRequest::graph_handle() const {
  return request_->graph_handle();
}

int64_t ProtoRunGraphRequest::step_id() const { return request_->step_id(); }

const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
  return request_->exec_opts();
}

size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }

const string& ProtoRunGraphRequest::send_key(size_t i) const {
  return request_->send(i).name();
}
578 | |
579 | Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { |
580 | if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) { |
581 | return errors::InvalidArgument("Invalid TensorProto for feed value " , i); |
582 | } else { |
583 | return OkStatus(); |
584 | } |
585 | } |
586 | |
size_t ProtoRunGraphRequest::num_recvs() const {
  return request_->recv_key_size();
}

const string& ProtoRunGraphRequest::recv_key(size_t i) const {
  return request_->recv_key(i);
}

bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }

bool ProtoRunGraphRequest::is_last_partial_run() const {
  return request_->is_last_partial_run();
}

bool ProtoRunGraphRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64_t ProtoRunGraphRequest::request_id() const {
  return request_->request_id();
}

// The borrowed proto is returned directly; no copy is made.
const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
  return *request_;
}
612 | |
// ---------------------------------------------------------------------------
// InMemoryRunGraphResponse: keeps recv tensors, stats, cost graph, and
// partition graphs as in-process objects; it has no proto backing store.
// ---------------------------------------------------------------------------

size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }

const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
  return recvs_[i].first;
}

// Serializes recv i into `*out_tensor` using AsProtoTensorContent.
Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
  recvs_[i].second.AsProtoTensorContent(out_tensor);
  return OkStatus();
}

// Copies the stored Tensor for recv i into `*out_tensor`.
Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
  *out_tensor = recvs_[i].second;
  return OkStatus();
}

void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
  recvs_.emplace_back(key, value);
}

StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
  return &step_stats_;
}

CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
  return &cost_graph_;
}

Status InMemoryRunGraphResponse::status() const { return status_; }

errors::Code InMemoryRunGraphResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunGraphResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunGraphResponse::set_status(const Status& status) {
  status_ = status;
}

// An in-memory response has no proto representation, so requesting one is a
// programming error and aborts the process.
RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse" ;
  return nullptr;
}

size_t InMemoryRunGraphResponse::num_partition_graphs() const {
  return partition_graphs_.size();
}

// `i` must be < num_partition_graphs(); the index is not range-checked.
GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
  return &partition_graphs_[i];
}

void InMemoryRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  partition_graphs_.push_back(partition_graph);
}
672 | |
// ---------------------------------------------------------------------------
// OwnedProtoRunGraphResponse: response wrapper backed by an owned
// RunGraphResponse proto (response_).
// ---------------------------------------------------------------------------

size_t OwnedProtoRunGraphResponse::num_recvs() const {
  return response_.recv_size();
}

const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_.recv(i).name();
}

// DESTRUCTIVE: swaps the stored tensor proto for recv i into `*out_tensor`,
// leaving the previous contents of `*out_tensor` inside the response.
// Avoids a copy but means recv i must not be read again afterwards.
Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
                                             TensorProto* out_tensor) {
  out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
  return OkStatus();
}
686 | |
687 | Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { |
688 | if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) { |
689 | return errors::InvalidArgument("Invalid TensorProto for recv value " , i); |
690 | } else { |
691 | return OkStatus(); |
692 | } |
693 | } |
694 | |
// Appends a recv, serializing `value` into the proto immediately.
void OwnedProtoRunGraphResponse::AddRecv(const string& key,
                                         const Tensor& value) {
  NamedTensorProto* recv = response_.add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_.mutable_step_stats();
}

CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_.mutable_cost_graph();
}

// Reconstructs a Status from the (code, message) pair stored in the proto.
Status OwnedProtoRunGraphResponse::status() const {
  return Status(response_.status_code(), response_.status_error_message());
}

errors::Code OwnedProtoRunGraphResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunGraphResponse::status_error_message() const {
  return response_.status_error_message();
}

// Stores the status as a (code, message) pair in the proto.
void OwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }

size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_.partition_graph_size();
}

GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_.mutable_partition_graph(i);
}

void OwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_.mutable_partition_graph()->Add();
  *graph_def = partition_graph;
}
743 | |
// ---------------------------------------------------------------------------
// NonOwnedProtoRunGraphResponse: same behavior as the owned variant, but the
// backing RunGraphResponse is borrowed and must outlive this wrapper.
// ---------------------------------------------------------------------------

NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
    RunGraphResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
  return response_->recv_size();
}

const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
  return response_->recv(i).name();
}

// DESTRUCTIVE: swaps the stored tensor proto for recv i into `*out_tensor`;
// recv i must not be read again afterwards.
Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
                                                TensorProto* out_tensor) {
  out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
  return OkStatus();
}
761 | |
762 | Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { |
763 | if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) { |
764 | return errors::InvalidArgument("Invalid TensorProto for recv value " , i); |
765 | } else { |
766 | return OkStatus(); |
767 | } |
768 | } |
769 | |
// Appends a recv, serializing `value` into the borrowed proto immediately.
void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
                                            const Tensor& value) {
  NamedTensorProto* recv = response_->add_recv();
  recv->set_name(key);
  TensorProto* value_proto = recv->mutable_tensor();
  value.AsProtoTensorContent(value_proto);
}

StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
  return response_->mutable_step_stats();
}

CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
  return response_->mutable_cost_graph();
}

// Reconstructs a Status from the (code, message) pair stored in the proto.
Status NonOwnedProtoRunGraphResponse::status() const {
  return Status(response_->status_code(), response_->status_error_message());
}

errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
  return response_->status_error_message();
}

// Stores the status as a (code, message) pair in the proto.
void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
  return response_;
}

size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
  return response_->partition_graph_size();
}

GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
  return response_->mutable_partition_graph(i);
}

void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
    const GraphDef& partition_graph) {
  GraphDef* graph_def = response_->add_partition_graph();
  *graph_def = partition_graph;
}
820 | |
// Out-of-line virtual destructor anchors the wrapper base class's vtable.
MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}

// ---------------------------------------------------------------------------
// InMemoryRunStepResponse: keeps fetched tensors and metadata as in-process
// objects; it has no proto backing store.
// ---------------------------------------------------------------------------

size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }

const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
  return tensors_[i].first;
}

// Copies fetched tensor i into `*out_tensor`.
// `i` must be < num_tensors(); the index is not range-checked here.
Status InMemoryRunStepResponse::TensorValue(size_t i,
                                            Tensor* out_tensor) const {
  *out_tensor = tensors_[i].second;
  return OkStatus();
}

const RunMetadata& InMemoryRunStepResponse::metadata() const {
  return metadata_;
}

// Pulls recv i out of `wrapper` and stores it under `name`. Note that
// wrapper->RecvValue may be destructive for proto-backed wrappers.
Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
  Tensor tensor;
  TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
  tensors_.emplace_back(name, tensor);
  return OkStatus();
}

RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }

Status InMemoryRunStepResponse::status() const { return status_; }

errors::Code InMemoryRunStepResponse::status_code() const {
  return status_.code();
}

const string& InMemoryRunStepResponse::status_error_message() const {
  return status_.error_message();
}

void InMemoryRunStepResponse::set_status(const Status& status) {
  status_ = status;
}

// An in-memory response has no proto representation, so requesting one is a
// programming error and aborts the process.
RunStepResponse* InMemoryRunStepResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse" ;
  return nullptr;
}
867 | |
// ---------------------------------------------------------------------------
// OwnedProtoRunStepResponse: response wrapper backed by an owned
// RunStepResponse proto (response_).
// ---------------------------------------------------------------------------

size_t OwnedProtoRunStepResponse::num_tensors() const {
  return response_.tensor_size();
}

const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_.tensor(i).name();
}
875 | |
876 | Status OwnedProtoRunStepResponse::TensorValue(size_t i, |
877 | Tensor* out_tensor) const { |
878 | if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) { |
879 | return errors::InvalidArgument("Invalid TensorProto for fetch value " , i); |
880 | } else { |
881 | return OkStatus(); |
882 | } |
883 | } |
884 | |
const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
  return response_.metadata();
}

// Streams recv i of `run_graph_response` directly into a new named tensor in
// this response's proto. Note that RecvValue may be destructive for
// proto-backed graph-response wrappers.
Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_.add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
  return response_.mutable_metadata();
}

// Reconstructs a Status from the (code, message) pair stored in the proto.
Status OwnedProtoRunStepResponse::status() const {
  return Status(response_.status_code(), response_.status_error_message());
}

errors::Code OwnedProtoRunStepResponse::status_code() const {
  return response_.status_code();
}

const string& OwnedProtoRunStepResponse::status_error_message() const {
  return response_.status_error_message();
}

// Stores the status as a (code, message) pair in the proto.
void OwnedProtoRunStepResponse::set_status(const Status& status) {
  response_.set_status_code(status.code());
  response_.set_status_error_message(status.error_message());
}

RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
919 | |
// ---------------------------------------------------------------------------
// NonOwnedProtoRunStepResponse: same behavior as the owned variant, but the
// backing RunStepResponse is borrowed and must outlive this wrapper.
// ---------------------------------------------------------------------------

NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
    RunStepResponse* response)
    : response_(response) {}

size_t NonOwnedProtoRunStepResponse::num_tensors() const {
  return response_->tensor_size();
}

const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
  return response_->tensor(i).name();
}
931 | |
932 | Status NonOwnedProtoRunStepResponse::TensorValue(size_t i, |
933 | Tensor* out_tensor) const { |
934 | if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) { |
935 | return errors::InvalidArgument("Invalid TensorProto for fetch value " , i); |
936 | } else { |
937 | return OkStatus(); |
938 | } |
939 | } |
940 | |
const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
  return response_->metadata();
}

// Streams recv i of `run_graph_response` directly into a new named tensor in
// the borrowed proto. Note that RecvValue may be destructive for
// proto-backed graph-response wrappers.
Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    size_t i) {
  NamedTensorProto* response_tensor = response_->add_tensor();
  response_tensor->set_name(name);
  return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
}

RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
  return response_->mutable_metadata();
}

// Reconstructs a Status from the (code, message) pair stored in the proto.
Status NonOwnedProtoRunStepResponse::status() const {
  return Status(response_->status_code(), response_->status_error_message());
}

errors::Code NonOwnedProtoRunStepResponse::status_code() const {
  return response_->status_code();
}

const string& NonOwnedProtoRunStepResponse::status_error_message() const {
  return response_->status_error_message();
}

// Stores the status as a (code, message) pair in the proto.
void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
  response_->set_status_code(status.code());
  response_->set_status_error_message(status.error_message());
}

RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
975 | |
976 | } // namespace tensorflow |
977 | |