1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/local_rendezvous.h" |
17 | |
18 | #include "tensorflow/core/framework/allocator.h" |
19 | #include "tensorflow/core/framework/types.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | #include "tensorflow/core/lib/core/notification.h" |
22 | #include "tensorflow/core/lib/gtl/manual_constructor.h" |
23 | #include "tensorflow/core/lib/monitoring/counter.h" |
24 | #include "tensorflow/core/lib/strings/numbers.h" |
25 | #include "tensorflow/core/lib/strings/str_util.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | #include "tensorflow/core/platform/mutex.h" |
28 | #include "tensorflow/core/platform/refcount.h" |
29 | #include "tensorflow/core/platform/types.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // Represents a blocked Send() or Recv() call in the rendezvous. |
34 | struct LocalRendezvous::Item { |
35 | enum Type { kSend = 0, kRecv = 1 }; |
36 | |
37 | Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead) |
38 | : Item(send_args, kSend) { |
39 | send_state.value.Init(value); |
40 | send_state.is_dead = is_dead; |
41 | } |
42 | |
43 | Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter, |
44 | CancellationToken cancellation_token) |
45 | : Item(recv_args, kRecv) { |
46 | recv_state.waiter.Init(std::move(waiter)); |
47 | recv_state.cancellation_token = cancellation_token; |
48 | } |
49 | |
50 | ~Item() { |
51 | if (args.device_context) { |
52 | args.device_context->Unref(); |
53 | } |
54 | if (type == kSend) { |
55 | send_state.value.Destroy(); |
56 | } else { |
57 | recv_state.waiter.Destroy(); |
58 | } |
59 | } |
60 | |
61 | const Rendezvous::Args args; |
62 | const Type type; |
63 | |
64 | // Link to next item in an ItemQueue. |
65 | Item* next = nullptr; |
66 | |
67 | // The validity of `send_state` or `recv_state` is determined by `type == |
68 | // kSend` or `type == kRecv` respectively. |
69 | union { |
70 | struct { |
71 | ManualConstructor<Tensor> value; |
72 | bool is_dead; |
73 | } send_state; |
74 | struct { |
75 | ManualConstructor<Rendezvous::DoneCallback> waiter; |
76 | CancellationToken cancellation_token; |
77 | } recv_state; |
78 | }; |
79 | |
80 | private: |
81 | Item(Rendezvous::Args args, Type type) : args(args), type(type) { |
82 | if (args.device_context) { |
83 | args.device_context->Ref(); |
84 | } |
85 | } |
86 | }; |
87 | |
88 | void LocalRendezvous::ItemQueue::push_back(Item* item) { |
89 | if (TF_PREDICT_TRUE(head == nullptr)) { |
90 | // The queue is empty. |
91 | head = item; |
92 | tail = item; |
93 | } else { |
94 | DCHECK_EQ(tail->type, item->type); |
95 | tail->next = item; |
96 | tail = item; |
97 | } |
98 | } |
99 | |
100 | LocalRendezvous::~LocalRendezvous() { |
101 | // Before destroying this rendezvous instance, make sure all the done-callback |
102 | // calls have finished and the tensors have been released from the queue. |
103 | { |
104 | mutex_lock l(mu_); |
105 | while (pending_callback_counter_ != 0) { |
106 | pending_callback_cond_var_.wait_for(l, std::chrono::milliseconds(50)); |
107 | } |
108 | } |
109 | |
110 | if (!table_.empty()) { |
111 | StartAbort(errors::Cancelled("LocalRendezvous deleted" )); |
112 | } |
113 | } |
114 | |
115 | namespace { |
116 | uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); } |
117 | } // namespace |
118 | |
119 | Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, |
120 | const Rendezvous::Args& send_args, |
121 | const Tensor& val, const bool is_dead) { |
122 | uint64 key_hash = KeyHash(key.FullKey()); |
123 | DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); |
124 | |
125 | if (is_dead) { |
126 | static auto* rendezvous_dead_values_sent = monitoring::Counter<2>::New( |
127 | "/tensorflow/core/rendezvous_dead_values_sent" , |
128 | "The number of dead values sent between a pair of devices." , |
129 | "send_device" , "recv_device" ); |
130 | rendezvous_dead_values_sent |
131 | ->GetCell(string(key.src_device), string(key.dst_device)) |
132 | ->IncrementBy(1); |
133 | } |
134 | |
135 | mu_.lock(); |
136 | if (!status_.ok()) { |
137 | // Rendezvous has been aborted. |
138 | Status s = status_; |
139 | mu_.unlock(); |
140 | return s; |
141 | } |
142 | |
143 | ItemQueue* queue = &table_[key_hash]; |
144 | if (queue->head == nullptr || queue->head->type == Item::kSend) { |
145 | // There is no waiter for this message. Append the message |
146 | // into the queue. The waiter will pick it up when arrives. |
147 | // Only send-related fields need to be filled. |
148 | // TODO(b/143786186): Investigate moving the allocation of `Item` outside |
149 | // the lock. |
150 | DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). " ; |
151 | queue->push_back(new Item(send_args, val, is_dead)); |
152 | mu_.unlock(); |
153 | return OkStatus(); |
154 | } |
155 | |
156 | DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). " ; |
157 | // There is an earliest waiter to consume this message. |
158 | Item* item = queue->head; |
159 | |
160 | // Delete the queue when the last element has been consumed. |
161 | if (item->next == nullptr) { |
162 | DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). " ; |
163 | table_.erase(key_hash); |
164 | } else { |
165 | queue->head = item->next; |
166 | } |
167 | |
168 | // Make sure the ref-count of the rendezvous won't reach 0 while the |
169 | // done_callback is running, which would otherwise become deadlock: |
170 | // the done_callback waits for the Unref() to return, while the destructor |
171 | // wiats for the pending_callback_counter to reach 0. |
172 | core::RefCountPtr<const Rendezvous> rc_owner_ref; |
173 | if (rc_owner_) { |
174 | rc_owner_ref.reset(rc_owner_); |
175 | rc_owner_->Ref(); |
176 | } |
177 | pending_callback_counter_++; |
178 | // Invoke the done-callback, without holding the lock. |
179 | mu_.unlock(); |
180 | DCHECK_EQ(item->type, Item::kRecv); |
181 | (*item->recv_state.waiter)(OkStatus(), send_args, item->args, val, is_dead); |
182 | delete item; |
183 | { |
184 | mutex_lock l(mu_); |
185 | pending_callback_counter_--; |
186 | if (pending_callback_counter_ == 0) { |
187 | pending_callback_cond_var_.notify_all(); |
188 | } |
189 | } |
190 | return OkStatus(); |
191 | } |
192 | |
193 | void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, |
194 | const Rendezvous::Args& recv_args, |
195 | Rendezvous::DoneCallback done) { |
196 | uint64 key_hash = KeyHash(key.FullKey()); |
197 | DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); |
198 | |
199 | mu_.lock(); |
200 | if (!status_.ok()) { |
201 | // Rendezvous has been aborted. |
202 | Status s = status_; |
203 | mu_.unlock(); |
204 | done(s, Rendezvous::Args(), recv_args, Tensor(), false); |
205 | return; |
206 | } |
207 | |
208 | ItemQueue* queue = &table_[key_hash]; |
209 | if (queue->head == nullptr || queue->head->type == Item::kRecv) { |
210 | // There is no message to pick up. |
211 | // Only recv-related fields need to be filled. |
212 | CancellationManager* cm = recv_args.cancellation_manager; |
213 | CancellationToken token = CancellationManager::kInvalidToken; |
214 | bool already_cancelled = false; |
215 | if (cm != nullptr) { |
216 | // Increment the refcount when cancellation manager is present, to make |
217 | // sure the rendezvous outlives the recv and its cancel callbacks. |
218 | // This refcount is dropped in exactly one of the following cases: |
219 | // (1) Recv registers cancellation callback to cm, and then cm is |
220 | // cancelled, unref in the cancellation callback; |
221 | // (2) Recv registers cancellation callback to cm, but cm is already |
222 | // cancelled, unref in the already_cancelled check; |
223 | // (3) Recv is successful, and item done callback finishes deregistering |
224 | // the cancellation callback, unref in the item done callback; |
225 | // (4) Recv is successful, but the item done callback fails to deregister |
226 | // the cancellation callback because cm already StartCancel, in this |
227 | // case the cancellation callback will be invoked by the cm anyway, |
228 | // unref in the cancellation callback. |
229 | if (rc_owner_) rc_owner_->Ref(); |
230 | token = cm->get_cancellation_token(); |
231 | already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] { |
232 | Item* item = nullptr; |
233 | { |
234 | mutex_lock l(mu_); |
235 | ItemQueue* queue = &table_[key_hash]; |
236 | // Find an item in the queue with a cancellation token that matches |
237 | // `token`, and remove it. |
238 | if (queue->head != nullptr && queue->head->type == Item::kRecv) { |
239 | for (Item *prev = nullptr, *curr = queue->head; curr != nullptr; |
240 | prev = curr, curr = curr->next) { |
241 | if (curr->recv_state.cancellation_token == token) { |
242 | item = curr; |
243 | if (queue->head->next == nullptr) { |
244 | // We have a single-element queue, so we can erase it from |
245 | // the table. |
246 | table_.erase(key_hash); |
247 | } else { |
248 | // Remove the current item from the queue. |
249 | if (curr == queue->head) { |
250 | DCHECK_EQ(prev, nullptr); |
251 | queue->head = curr->next; |
252 | } else { |
253 | DCHECK_NE(prev, nullptr); |
254 | prev->next = curr->next; |
255 | } |
256 | if (queue->tail == curr) { |
257 | queue->tail = prev; |
258 | } |
259 | } |
260 | break; |
261 | } |
262 | } |
263 | } |
264 | } |
265 | |
266 | if (item != nullptr) { |
267 | (*item->recv_state.waiter)( |
268 | StatusGroup::MakeDerived( |
269 | errors::Cancelled("RecvAsync is cancelled." )), |
270 | Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false); |
271 | delete item; |
272 | } |
273 | // Unref case (1) and (4) |
274 | if (rc_owner_) rc_owner_->Unref(); |
275 | }); |
276 | } |
277 | if (already_cancelled) { |
278 | mu_.unlock(); |
279 | // Unref case (2) |
280 | if (rc_owner_) rc_owner_->Unref(); |
281 | done(StatusGroup::MakeDerived( |
282 | errors::Cancelled("RecvAsync is cancelled." )), |
283 | Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false); |
284 | return; |
285 | } |
286 | |
287 | DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). " ; |
288 | |
289 | // TODO(b/143786186): Investigate moving the allocation of `Item` outside |
290 | // the lock. |
291 | if (cm != nullptr) { |
292 | // NOTE(mrry): We must wrap `done` with code that deregisters the |
293 | // cancellation callback before calling the `done` callback, because the |
294 | // cancellation manager may no longer be live after `done` is called. |
295 | queue->push_back(new Item( |
296 | recv_args, |
297 | [this, cm, token, done = std::move(done)]( |
298 | const Status& s, const Rendezvous::Args& send_args, |
299 | const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { |
300 | // TryDeregisterCallback returns true when the cancellation callback |
301 | // is successfully deregistered. If it fails because the CM already |
302 | // StartAbort, Unref will happen inside the cancellation callback |
303 | // when called by the CM. |
304 | if (cm->TryDeregisterCallback(token)) { |
305 | // Unref case (3) |
306 | if (this->rc_owner_) this->rc_owner_->Unref(); |
307 | } |
308 | done(s, send_args, recv_args, v, dead); |
309 | }, |
310 | token)); |
311 | } else { |
312 | queue->push_back(new Item(recv_args, std::move(done), token)); |
313 | } |
314 | |
315 | mu_.unlock(); |
316 | return; |
317 | } |
318 | |
319 | DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). " ; |
320 | // A message has already arrived and is queued in the table under |
321 | // this key. Consumes the message and invokes the done closure. |
322 | Item* item = queue->head; |
323 | |
324 | // Delete the queue when the last element has been consumed. |
325 | if (item->next == nullptr) { |
326 | DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). " ; |
327 | table_.erase(key_hash); |
328 | } else { |
329 | queue->head = item->next; |
330 | } |
331 | |
332 | // Make sure the ref-count of the rendezvous won't reach 0 while the |
333 | // done_callback is running, which would otherwise become deadlock: |
334 | // the done_callback waits for the Unref() to return, while the destructor |
335 | // wiats for the pending_callback_counter to reach 0. |
336 | core::RefCountPtr<const Rendezvous> rc_owner_ref; |
337 | if (rc_owner_) { |
338 | rc_owner_ref.reset(rc_owner_); |
339 | rc_owner_->Ref(); |
340 | } |
341 | pending_callback_counter_++; |
342 | // Invoke the done-callback, without holding the lock. |
343 | mu_.unlock(); |
344 | DCHECK_EQ(item->type, Item::kSend); |
345 | done(OkStatus(), item->args, recv_args, *item->send_state.value, |
346 | item->send_state.is_dead); |
347 | delete item; |
348 | { |
349 | mutex_lock l(mu_); |
350 | pending_callback_counter_--; |
351 | if (pending_callback_counter_ == 0) { |
352 | pending_callback_cond_var_.notify_all(); |
353 | } |
354 | } |
355 | } |
356 | |
357 | void LocalRendezvous::StartAbort(const Status& status) { |
358 | CHECK(!status.ok()); |
359 | Table table; |
360 | { |
361 | mutex_lock l(mu_); |
362 | status_.Update(status); |
363 | table_.swap(table); |
364 | } |
365 | for (auto& p : table) { |
366 | Item* item = p.second.head; |
367 | while (item != nullptr) { |
368 | if (item->type == Item::kRecv) { |
369 | (*item->recv_state.waiter)(status, Rendezvous::Args(), |
370 | Rendezvous::Args(), Tensor(), false); |
371 | } |
372 | Item* to_delete = item; |
373 | item = item->next; |
374 | delete to_delete; |
375 | } |
376 | } |
377 | } |
378 | |
379 | Status LocalRendezvous::status() { |
380 | mu_.lock(); |
381 | Status s = status_; |
382 | mu_.unlock(); |
383 | return s; |
384 | } |
385 | |
386 | } // namespace tensorflow |
387 | |