1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/local_rendezvous.h"
17
18#include "tensorflow/core/framework/allocator.h"
19#include "tensorflow/core/framework/types.h"
20#include "tensorflow/core/lib/core/errors.h"
21#include "tensorflow/core/lib/core/notification.h"
22#include "tensorflow/core/lib/gtl/manual_constructor.h"
23#include "tensorflow/core/lib/monitoring/counter.h"
24#include "tensorflow/core/lib/strings/numbers.h"
25#include "tensorflow/core/lib/strings/str_util.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/platform/mutex.h"
28#include "tensorflow/core/platform/refcount.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace tensorflow {
32
33// Represents a blocked Send() or Recv() call in the rendezvous.
34struct LocalRendezvous::Item {
35 enum Type { kSend = 0, kRecv = 1 };
36
37 Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead)
38 : Item(send_args, kSend) {
39 send_state.value.Init(value);
40 send_state.is_dead = is_dead;
41 }
42
43 Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter,
44 CancellationToken cancellation_token)
45 : Item(recv_args, kRecv) {
46 recv_state.waiter.Init(std::move(waiter));
47 recv_state.cancellation_token = cancellation_token;
48 }
49
50 ~Item() {
51 if (args.device_context) {
52 args.device_context->Unref();
53 }
54 if (type == kSend) {
55 send_state.value.Destroy();
56 } else {
57 recv_state.waiter.Destroy();
58 }
59 }
60
61 const Rendezvous::Args args;
62 const Type type;
63
64 // Link to next item in an ItemQueue.
65 Item* next = nullptr;
66
67 // The validity of `send_state` or `recv_state` is determined by `type ==
68 // kSend` or `type == kRecv` respectively.
69 union {
70 struct {
71 ManualConstructor<Tensor> value;
72 bool is_dead;
73 } send_state;
74 struct {
75 ManualConstructor<Rendezvous::DoneCallback> waiter;
76 CancellationToken cancellation_token;
77 } recv_state;
78 };
79
80 private:
81 Item(Rendezvous::Args args, Type type) : args(args), type(type) {
82 if (args.device_context) {
83 args.device_context->Ref();
84 }
85 }
86};
87
88void LocalRendezvous::ItemQueue::push_back(Item* item) {
89 if (TF_PREDICT_TRUE(head == nullptr)) {
90 // The queue is empty.
91 head = item;
92 tail = item;
93 } else {
94 DCHECK_EQ(tail->type, item->type);
95 tail->next = item;
96 tail = item;
97 }
98}
99
100LocalRendezvous::~LocalRendezvous() {
101 // Before destroying this rendezvous instance, make sure all the done-callback
102 // calls have finished and the tensors have been released from the queue.
103 {
104 mutex_lock l(mu_);
105 while (pending_callback_counter_ != 0) {
106 pending_callback_cond_var_.wait_for(l, std::chrono::milliseconds(50));
107 }
108 }
109
110 if (!table_.empty()) {
111 StartAbort(errors::Cancelled("LocalRendezvous deleted"));
112 }
113}
114
115namespace {
116uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); }
117} // namespace
118
119Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key,
120 const Rendezvous::Args& send_args,
121 const Tensor& val, const bool is_dead) {
122 uint64 key_hash = KeyHash(key.FullKey());
123 DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
124
125 if (is_dead) {
126 static auto* rendezvous_dead_values_sent = monitoring::Counter<2>::New(
127 "/tensorflow/core/rendezvous_dead_values_sent",
128 "The number of dead values sent between a pair of devices.",
129 "send_device", "recv_device");
130 rendezvous_dead_values_sent
131 ->GetCell(string(key.src_device), string(key.dst_device))
132 ->IncrementBy(1);
133 }
134
135 mu_.lock();
136 if (!status_.ok()) {
137 // Rendezvous has been aborted.
138 Status s = status_;
139 mu_.unlock();
140 return s;
141 }
142
143 ItemQueue* queue = &table_[key_hash];
144 if (queue->head == nullptr || queue->head->type == Item::kSend) {
145 // There is no waiter for this message. Append the message
146 // into the queue. The waiter will pick it up when arrives.
147 // Only send-related fields need to be filled.
148 // TODO(b/143786186): Investigate moving the allocation of `Item` outside
149 // the lock.
150 DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
151 queue->push_back(new Item(send_args, val, is_dead));
152 mu_.unlock();
153 return OkStatus();
154 }
155
156 DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
157 // There is an earliest waiter to consume this message.
158 Item* item = queue->head;
159
160 // Delete the queue when the last element has been consumed.
161 if (item->next == nullptr) {
162 DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
163 table_.erase(key_hash);
164 } else {
165 queue->head = item->next;
166 }
167
168 // Make sure the ref-count of the rendezvous won't reach 0 while the
169 // done_callback is running, which would otherwise become deadlock:
170 // the done_callback waits for the Unref() to return, while the destructor
171 // wiats for the pending_callback_counter to reach 0.
172 core::RefCountPtr<const Rendezvous> rc_owner_ref;
173 if (rc_owner_) {
174 rc_owner_ref.reset(rc_owner_);
175 rc_owner_->Ref();
176 }
177 pending_callback_counter_++;
178 // Invoke the done-callback, without holding the lock.
179 mu_.unlock();
180 DCHECK_EQ(item->type, Item::kRecv);
181 (*item->recv_state.waiter)(OkStatus(), send_args, item->args, val, is_dead);
182 delete item;
183 {
184 mutex_lock l(mu_);
185 pending_callback_counter_--;
186 if (pending_callback_counter_ == 0) {
187 pending_callback_cond_var_.notify_all();
188 }
189 }
190 return OkStatus();
191}
192
193void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
194 const Rendezvous::Args& recv_args,
195 Rendezvous::DoneCallback done) {
196 uint64 key_hash = KeyHash(key.FullKey());
197 DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
198
199 mu_.lock();
200 if (!status_.ok()) {
201 // Rendezvous has been aborted.
202 Status s = status_;
203 mu_.unlock();
204 done(s, Rendezvous::Args(), recv_args, Tensor(), false);
205 return;
206 }
207
208 ItemQueue* queue = &table_[key_hash];
209 if (queue->head == nullptr || queue->head->type == Item::kRecv) {
210 // There is no message to pick up.
211 // Only recv-related fields need to be filled.
212 CancellationManager* cm = recv_args.cancellation_manager;
213 CancellationToken token = CancellationManager::kInvalidToken;
214 bool already_cancelled = false;
215 if (cm != nullptr) {
216 // Increment the refcount when cancellation manager is present, to make
217 // sure the rendezvous outlives the recv and its cancel callbacks.
218 // This refcount is dropped in exactly one of the following cases:
219 // (1) Recv registers cancellation callback to cm, and then cm is
220 // cancelled, unref in the cancellation callback;
221 // (2) Recv registers cancellation callback to cm, but cm is already
222 // cancelled, unref in the already_cancelled check;
223 // (3) Recv is successful, and item done callback finishes deregistering
224 // the cancellation callback, unref in the item done callback;
225 // (4) Recv is successful, but the item done callback fails to deregister
226 // the cancellation callback because cm already StartCancel, in this
227 // case the cancellation callback will be invoked by the cm anyway,
228 // unref in the cancellation callback.
229 if (rc_owner_) rc_owner_->Ref();
230 token = cm->get_cancellation_token();
231 already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
232 Item* item = nullptr;
233 {
234 mutex_lock l(mu_);
235 ItemQueue* queue = &table_[key_hash];
236 // Find an item in the queue with a cancellation token that matches
237 // `token`, and remove it.
238 if (queue->head != nullptr && queue->head->type == Item::kRecv) {
239 for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
240 prev = curr, curr = curr->next) {
241 if (curr->recv_state.cancellation_token == token) {
242 item = curr;
243 if (queue->head->next == nullptr) {
244 // We have a single-element queue, so we can erase it from
245 // the table.
246 table_.erase(key_hash);
247 } else {
248 // Remove the current item from the queue.
249 if (curr == queue->head) {
250 DCHECK_EQ(prev, nullptr);
251 queue->head = curr->next;
252 } else {
253 DCHECK_NE(prev, nullptr);
254 prev->next = curr->next;
255 }
256 if (queue->tail == curr) {
257 queue->tail = prev;
258 }
259 }
260 break;
261 }
262 }
263 }
264 }
265
266 if (item != nullptr) {
267 (*item->recv_state.waiter)(
268 StatusGroup::MakeDerived(
269 errors::Cancelled("RecvAsync is cancelled.")),
270 Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
271 delete item;
272 }
273 // Unref case (1) and (4)
274 if (rc_owner_) rc_owner_->Unref();
275 });
276 }
277 if (already_cancelled) {
278 mu_.unlock();
279 // Unref case (2)
280 if (rc_owner_) rc_owner_->Unref();
281 done(StatusGroup::MakeDerived(
282 errors::Cancelled("RecvAsync is cancelled.")),
283 Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
284 return;
285 }
286
287 DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
288
289 // TODO(b/143786186): Investigate moving the allocation of `Item` outside
290 // the lock.
291 if (cm != nullptr) {
292 // NOTE(mrry): We must wrap `done` with code that deregisters the
293 // cancellation callback before calling the `done` callback, because the
294 // cancellation manager may no longer be live after `done` is called.
295 queue->push_back(new Item(
296 recv_args,
297 [this, cm, token, done = std::move(done)](
298 const Status& s, const Rendezvous::Args& send_args,
299 const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
300 // TryDeregisterCallback returns true when the cancellation callback
301 // is successfully deregistered. If it fails because the CM already
302 // StartAbort, Unref will happen inside the cancellation callback
303 // when called by the CM.
304 if (cm->TryDeregisterCallback(token)) {
305 // Unref case (3)
306 if (this->rc_owner_) this->rc_owner_->Unref();
307 }
308 done(s, send_args, recv_args, v, dead);
309 },
310 token));
311 } else {
312 queue->push_back(new Item(recv_args, std::move(done), token));
313 }
314
315 mu_.unlock();
316 return;
317 }
318
319 DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
320 // A message has already arrived and is queued in the table under
321 // this key. Consumes the message and invokes the done closure.
322 Item* item = queue->head;
323
324 // Delete the queue when the last element has been consumed.
325 if (item->next == nullptr) {
326 DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
327 table_.erase(key_hash);
328 } else {
329 queue->head = item->next;
330 }
331
332 // Make sure the ref-count of the rendezvous won't reach 0 while the
333 // done_callback is running, which would otherwise become deadlock:
334 // the done_callback waits for the Unref() to return, while the destructor
335 // wiats for the pending_callback_counter to reach 0.
336 core::RefCountPtr<const Rendezvous> rc_owner_ref;
337 if (rc_owner_) {
338 rc_owner_ref.reset(rc_owner_);
339 rc_owner_->Ref();
340 }
341 pending_callback_counter_++;
342 // Invoke the done-callback, without holding the lock.
343 mu_.unlock();
344 DCHECK_EQ(item->type, Item::kSend);
345 done(OkStatus(), item->args, recv_args, *item->send_state.value,
346 item->send_state.is_dead);
347 delete item;
348 {
349 mutex_lock l(mu_);
350 pending_callback_counter_--;
351 if (pending_callback_counter_ == 0) {
352 pending_callback_cond_var_.notify_all();
353 }
354 }
355}
356
357void LocalRendezvous::StartAbort(const Status& status) {
358 CHECK(!status.ok());
359 Table table;
360 {
361 mutex_lock l(mu_);
362 status_.Update(status);
363 table_.swap(table);
364 }
365 for (auto& p : table) {
366 Item* item = p.second.head;
367 while (item != nullptr) {
368 if (item->type == Item::kRecv) {
369 (*item->recv_state.waiter)(status, Rendezvous::Args(),
370 Rendezvous::Args(), Tensor(), false);
371 }
372 Item* to_delete = item;
373 item = item->next;
374 delete to_delete;
375 }
376 }
377}
378
379Status LocalRendezvous::status() {
380 mu_.lock();
381 Status s = status_;
382 mu_.unlock();
383 return s;
384}
385
386} // namespace tensorflow
387