1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/rendezvous_util.h" |
16 | #include "tensorflow/core/platform/mutex.h" |
17 | |
18 | #include "tensorflow/core/util/reffed_status_callback.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | Status SendTensorsToRendezvous( |
23 | RendezvousInterface* rendezvous, DeviceContext* device_context, |
24 | const std::vector<AllocatorAttributes>& alloc_attrs, |
25 | const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) { |
26 | if (keys.size() != tensors_to_send.size()) { |
27 | return errors::InvalidArgument( |
28 | "keys and tensors_to_send are not the same size. keys.size() = " , |
29 | keys.size(), "; tensors_to_send.size() = " , tensors_to_send.size()); |
30 | } |
31 | if (!alloc_attrs.empty() && (keys.size() != alloc_attrs.size())) { |
32 | return errors::InvalidArgument( |
33 | "keys and alloc_attrs are not the same size. " , |
34 | "keys.size() = " , keys.size(), |
35 | "; alloc_attrs.size() = " , alloc_attrs.size()); |
36 | } |
37 | |
38 | if (!rendezvous) { |
39 | return errors::InvalidArgument("Rendezvous is null." ); |
40 | } |
41 | |
42 | Rendezvous::ParsedKey parsed; |
43 | for (int i = 0; i < keys.size(); ++i) { |
44 | Rendezvous::Args rendez_args; |
45 | rendez_args.device_context = device_context; |
46 | if (!alloc_attrs.empty()) { |
47 | rendez_args.alloc_attrs = alloc_attrs[i]; |
48 | } |
49 | TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed)); |
50 | TF_RETURN_IF_ERROR( |
51 | rendezvous->Send(parsed, rendez_args, tensors_to_send[i], false)); |
52 | } |
53 | return OkStatus(); |
54 | } |
55 | |
56 | void RecvOutputsFromRendezvousAsync( |
57 | RendezvousInterface* rendezvous, DeviceContext* device_context, |
58 | const std::vector<AllocatorAttributes>& alloc_attrs, |
59 | const std::vector<string>& keys, std::vector<Tensor>* received_tensors, |
60 | StatusCallback done) { |
61 | if (keys.empty()) { |
62 | done(OkStatus()); |
63 | return; |
64 | } |
65 | if (!alloc_attrs.empty() && (keys.size() != alloc_attrs.size())) { |
66 | done(errors::InvalidArgument( |
67 | "keys and alloc_attrs are not the same size. " , "keys.size() = " , |
68 | keys.size(), "; alloc_attrs.size() = " , alloc_attrs.size())); |
69 | } |
70 | |
71 | received_tensors->reserve(keys.size()); |
72 | std::vector< |
73 | std::tuple<string, Tensor*, Rendezvous::ParsedKey, AllocatorAttributes>> |
74 | arguments; |
75 | for (int i = 0; i < keys.size(); ++i) { |
76 | Rendezvous::ParsedKey parsed; |
77 | Status s = Rendezvous::ParseKey(keys[i], &parsed); |
78 | received_tensors->push_back(Tensor()); |
79 | if (!s.ok()) { |
80 | done(s); |
81 | return; |
82 | } |
83 | AllocatorAttributes alloc_attr; |
84 | if (!alloc_attrs.empty()) { |
85 | alloc_attr = alloc_attrs[i]; |
86 | } |
87 | arguments.emplace_back(keys[i], &((*received_tensors)[i]), parsed, |
88 | alloc_attr); |
89 | } |
90 | |
91 | auto status_cb = new ReffedStatusCallback(std::move(done)); |
92 | for (auto& p : arguments) { |
93 | const string& key = std::get<0>(p); |
94 | Tensor* val = std::get<1>(p); |
95 | Rendezvous::ParsedKey parsed = std::get<2>(p); |
96 | Rendezvous::Args rendez_args; |
97 | rendez_args.device_context = device_context; |
98 | rendez_args.alloc_attrs = std::get<3>(p); |
99 | status_cb->Ref(); |
100 | rendezvous->RecvAsync( |
101 | parsed, rendez_args, |
102 | [val, key, status_cb](const Status& s, |
103 | const Rendezvous::Args& send_args, |
104 | const Rendezvous::Args& recv_args, |
105 | const Tensor& v, const bool is_dead) { |
106 | Status status = s; |
107 | if (status.ok()) { |
108 | *val = v; |
109 | if (is_dead) { |
110 | status = errors::InvalidArgument("The tensor returned for " , key, |
111 | " was not valid." ); |
112 | } |
113 | } |
114 | status_cb->UpdateStatus(status); |
115 | status_cb->Unref(); |
116 | }); |
117 | } |
118 | status_cb->Unref(); |
119 | } |
120 | |
121 | Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, |
122 | NamedTensors* out, |
123 | const Rendezvous::Args& args) { |
124 | // Receives values requested by the caller. |
125 | Rendezvous::ParsedKey parsed; |
126 | for (auto& p : *out) { |
127 | const string& key = p.first; |
128 | Tensor* val = &p.second; |
129 | bool is_dead = false; |
130 | TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed)); |
131 | TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead)); |
132 | if (is_dead) { |
133 | return errors::InvalidArgument("The tensor returned for " , key, |
134 | " was not valid." ); |
135 | } |
136 | } |
137 | return OkStatus(); |
138 | } |
139 | |
140 | } // namespace tensorflow |
141 | |