1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/distributed_runtime/local_master.h"

#include <limits>
#include <unordered_map>

#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/platform/mutex.h"
22 | |
23 | namespace tensorflow { |
24 | |
25 | namespace { |
26 | Status WaitForNotification(CallOptions* call_options, |
27 | const int64_t default_timeout_in_ms, |
28 | Notification* n) { |
29 | int64_t timeout_in_ms = call_options->GetTimeout(); |
30 | if (timeout_in_ms == 0) { |
31 | timeout_in_ms = default_timeout_in_ms; |
32 | } |
33 | if (timeout_in_ms > 0) { |
34 | int64_t timeout_in_us = timeout_in_ms * 1000; |
35 | bool notified = WaitForNotificationWithTimeout(n, timeout_in_us); |
36 | if (!notified) { |
37 | call_options->StartCancel(); |
38 | // The call has borrowed pointers to the request and response |
39 | // messages, so we must still wait for the call to complete. |
40 | n->WaitForNotification(); |
41 | return errors::DeadlineExceeded("Operation timed out." ); |
42 | } |
43 | } else { |
44 | n->WaitForNotification(); |
45 | } |
46 | return OkStatus(); |
47 | } |
48 | } // namespace |
49 | |
50 | LocalMaster::LocalMaster(Master* master_impl, |
51 | const int64_t default_timeout_in_ms) |
52 | : master_impl_(master_impl), |
53 | default_timeout_in_ms_(default_timeout_in_ms) {} |
54 | |
55 | Status LocalMaster::CreateSession(CallOptions* call_options, |
56 | const CreateSessionRequest* request, |
57 | CreateSessionResponse* response) { |
58 | Notification n; |
59 | Status ret; |
60 | master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) { |
61 | ret.Update(s); |
62 | n.Notify(); |
63 | }); |
64 | TF_RETURN_IF_ERROR( |
65 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
66 | return ret; |
67 | } |
68 | |
69 | Status LocalMaster::ExtendSession(CallOptions* call_options, |
70 | const ExtendSessionRequest* request, |
71 | ExtendSessionResponse* response) { |
72 | Notification n; |
73 | Status ret; |
74 | master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) { |
75 | ret.Update(s); |
76 | n.Notify(); |
77 | }); |
78 | TF_RETURN_IF_ERROR( |
79 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
80 | return ret; |
81 | } |
82 | |
83 | Status LocalMaster::PartialRunSetup(CallOptions* call_options, |
84 | const PartialRunSetupRequest* request, |
85 | PartialRunSetupResponse* response) { |
86 | Notification n; |
87 | Status ret; |
88 | master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) { |
89 | ret.Update(s); |
90 | n.Notify(); |
91 | }); |
92 | TF_RETURN_IF_ERROR( |
93 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
94 | return ret; |
95 | } |
96 | |
97 | Status LocalMaster::RunStep(CallOptions* call_options, |
98 | RunStepRequestWrapper* request, |
99 | MutableRunStepResponseWrapper* response) { |
100 | Notification n; |
101 | Status ret; |
102 | master_impl_->RunStep(call_options, request, response, |
103 | [&n, &ret](const Status& s) { |
104 | ret.Update(s); |
105 | n.Notify(); |
106 | }); |
107 | TF_RETURN_IF_ERROR( |
108 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
109 | return ret; |
110 | } |
111 | |
112 | MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() { |
113 | return new InMemoryRunStepRequest; |
114 | } |
115 | |
116 | MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() { |
117 | return new InMemoryRunStepResponse; |
118 | } |
119 | |
120 | Status LocalMaster::CloseSession(CallOptions* call_options, |
121 | const CloseSessionRequest* request, |
122 | CloseSessionResponse* response) { |
123 | Notification n; |
124 | Status ret; |
125 | master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) { |
126 | ret.Update(s); |
127 | n.Notify(); |
128 | }); |
129 | TF_RETURN_IF_ERROR( |
130 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
131 | return ret; |
132 | } |
133 | |
134 | Status LocalMaster::ListDevices(CallOptions* call_options, |
135 | const ListDevicesRequest* request, |
136 | ListDevicesResponse* response) { |
137 | Notification n; |
138 | Status ret; |
139 | master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) { |
140 | ret.Update(s); |
141 | n.Notify(); |
142 | }); |
143 | TF_RETURN_IF_ERROR( |
144 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
145 | return ret; |
146 | } |
147 | |
148 | Status LocalMaster::Reset(CallOptions* call_options, |
149 | const ResetRequest* request, |
150 | ResetResponse* response) { |
151 | Notification n; |
152 | Status ret; |
153 | master_impl_->Reset(request, response, [&n, &ret](const Status& s) { |
154 | ret.Update(s); |
155 | n.Notify(); |
156 | }); |
157 | TF_RETURN_IF_ERROR( |
158 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
159 | return ret; |
160 | } |
161 | |
162 | Status LocalMaster::MakeCallable(CallOptions* call_options, |
163 | const MakeCallableRequest* request, |
164 | MakeCallableResponse* response) { |
165 | Notification n; |
166 | Status ret; |
167 | master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) { |
168 | ret.Update(s); |
169 | n.Notify(); |
170 | }); |
171 | TF_RETURN_IF_ERROR( |
172 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
173 | return ret; |
174 | } |
175 | Status LocalMaster::RunCallable(CallOptions* call_options, |
176 | const RunCallableRequest* request, |
177 | RunCallableResponse* response) { |
178 | Notification n; |
179 | Status ret; |
180 | master_impl_->RunCallable(call_options, request, response, |
181 | [&n, &ret](const Status& s) { |
182 | ret.Update(s); |
183 | n.Notify(); |
184 | }); |
185 | TF_RETURN_IF_ERROR( |
186 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
187 | return ret; |
188 | } |
189 | Status LocalMaster::ReleaseCallable(CallOptions* call_options, |
190 | const ReleaseCallableRequest* request, |
191 | ReleaseCallableResponse* response) { |
192 | Notification n; |
193 | Status ret; |
194 | master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) { |
195 | ret.Update(s); |
196 | n.Notify(); |
197 | }); |
198 | TF_RETURN_IF_ERROR( |
199 | WaitForNotification(call_options, default_timeout_in_ms_, &n)); |
200 | return ret; |
201 | } |
202 | |
203 | namespace { |
204 | mutex* get_local_master_registry_lock() { |
205 | static mutex local_master_registry_lock(LINKER_INITIALIZED); |
206 | return &local_master_registry_lock; |
207 | } |
208 | |
209 | struct MasterInfo { |
210 | Master* master; |
211 | const int64_t default_timeout_in_ms; |
212 | |
213 | MasterInfo(Master* master, const int64_t default_timeout_in_ms) |
214 | : master(master), default_timeout_in_ms(default_timeout_in_ms) {} |
215 | }; |
216 | |
217 | typedef std::unordered_map<string, MasterInfo> LocalMasterRegistry; |
218 | LocalMasterRegistry* local_master_registry() { |
219 | static LocalMasterRegistry* local_master_registry_ = new LocalMasterRegistry; |
220 | return local_master_registry_; |
221 | } |
222 | } // namespace |
223 | |
224 | /* static */ |
225 | void LocalMaster::Register(const string& target, Master* master, |
226 | int64_t default_timeout_in_ms) { |
227 | mutex_lock l(*get_local_master_registry_lock()); |
228 | local_master_registry()->insert( |
229 | {target, MasterInfo(master, default_timeout_in_ms)}); |
230 | } |
231 | |
232 | /* static */ |
233 | std::unique_ptr<LocalMaster> LocalMaster::Lookup(const string& target) { |
234 | std::unique_ptr<LocalMaster> ret; |
235 | mutex_lock l(*get_local_master_registry_lock()); |
236 | auto iter = local_master_registry()->find(target); |
237 | if (iter != local_master_registry()->end()) { |
238 | ret.reset(new LocalMaster(iter->second.master, |
239 | iter->second.default_timeout_in_ms)); |
240 | } |
241 | return ret; |
242 | } |
243 | |
244 | } // namespace tensorflow |
245 | |