1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/buf_rendezvous.h"
16
17#include "absl/strings/numbers.h"
18#include "absl/strings/str_cat.h"
19#include "absl/strings/string_view.h"
20#include "tensorflow/core/common_runtime/device.h"
21#include "tensorflow/core/common_runtime/device_mgr.h"
22#include "tensorflow/core/common_runtime/process_util.h"
23#include "tensorflow/core/framework/cancellation.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/lib/core/notification.h"
26
27namespace tensorflow {
28namespace {
29void DeregisterCancellation(BufRendezvous::Hook* h) {
30 if (h->cancellation_manager != nullptr) {
31 h->cancellation_manager->DeregisterCallback(h->cancellation_token);
32 h->cancellation_manager = nullptr;
33 h->cancellation_token = CancellationManager::kInvalidToken;
34 }
35}
36} // namespace
37
38BufRendezvous::~BufRendezvous() {
39 mutex_lock l(mu_);
40 if (!hook_table_.empty()) {
41 PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
42 &hook_table_);
43 }
44}
45
46void BufRendezvous::StartAbort(const Status& s) {
47 CHECK(!s.ok());
48 HookTable dummy_table;
49 {
50 mutex_lock l(mu_);
51 // Use a "derived" status as the status for the rendezvous. Derived
52 // status messages are ignored when aggregating errors across devices: this
53 // allows us to prefer our original status message over any cancellation
54 // related errors.
55 status_.Update(StatusGroup::MakeDerived(s));
56 hook_table_.swap(dummy_table);
57 }
58 PurgeTable(s, &dummy_table);
59}
60
61void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
62 for (auto& it : *table) {
63 Hook* h = it.second;
64 if (h->cancellation_manager != nullptr) {
65 h->cancellation_manager->TryDeregisterCallback(h->cancellation_token);
66 }
67 if (h->cons_cb != nullptr) {
68 h->cons_cb(s, nullptr);
69 }
70 if (h->prod_cb != nullptr) {
71 h->prod_cb(s);
72 }
73 delete h;
74 }
75 table->clear();
76}
77
78string BufRendezvous::Hook::DebugString() const {
79 return absl::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
80 ", ctx:", reinterpret_cast<uint64>(prod_ctx),
81 ", val:", reinterpret_cast<uint64>(prod_value),
82 ", pcb:", reinterpret_cast<uint64>(&prod_cb),
83 ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
84}
85
86void BufRendezvous::ProvideBuf(const string& key, Device* dev,
87 DeviceContext* dev_ctx, const Tensor* v,
88 const AllocatorAttributes& attr,
89 const ProducerCallback& done,
90 CancellationManager* cancellation_manager) {
91 Hook* h = nullptr;
92 Status providebuf_status;
93 do {
94 mutex_lock l(mu_);
95 if (!status_.ok()) {
96 providebuf_status = status_;
97 break;
98 } else {
99 CancellationToken cancellation_token = CancellationManager::kInvalidToken;
100 auto it = hook_table_.find(key);
101 if (it == hook_table_.end()) {
102 if (cancellation_manager != nullptr) {
103 cancellation_token = cancellation_manager->get_cancellation_token();
104 }
105 h = new Hook(cancellation_manager, cancellation_token);
106 it = hook_table_.insert(std::make_pair(key, h)).first;
107 } else {
108 if (it->second->prod_cb != nullptr) {
109 providebuf_status = errors::Internal(
110 "BufRendezvous::ProvideBuf already called for key ", key);
111 break;
112 }
113 h = it->second;
114 }
115 // Populate Hook with all of the prod values.
116 h->prod_dev = dev;
117 h->prod_ctx = dev_ctx;
118 h->prod_value = v;
119 h->prod_attr = attr;
120 h->prod_cb = done;
121 if (h->cons_cb != nullptr) {
122 // If consumer is waiting, kick off right away, removing Hook from
123 // table.
124 hook_table_.erase(it);
125 } else {
126 if (cancellation_manager != nullptr &&
127 !cancellation_manager->RegisterCallback(
128 cancellation_token, [this, key]() { CancelHook(key); })) {
129 // Register cancellation callback with CancellationManager. If it is
130 // already cancelled, call done immediately with cancelled status.
131 providebuf_status = errors::Cancelled(
132 "Operation was cancelled for BufRendezvous key ", key);
133 hook_table_.erase(it);
134 delete h;
135 }
136 h = nullptr;
137 }
138 }
139 } while (false);
140 if (h) {
141 DeregisterCancellation(h);
142 h->cons_cb(OkStatus(), h);
143 }
144 if (!providebuf_status.ok()) {
145 done(providebuf_status);
146 }
147}
148
149void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
150 const uint64 device_incarnation,
151 const ConsumerCallback& done,
152 CancellationManager* cancellation_manager) {
153 // Check the incarnation in the request matches the current device
154 // incarnation of the producer.
155 Device* device;
156 Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device);
157 if (consumebuf_status.ok() &&
158 device->attributes().incarnation() != device_incarnation) {
159 consumebuf_status = errors::FailedPrecondition(
160 "RecvBuf expects a different device incarnation: ", device_incarnation,
161 " vs. ", device->attributes().incarnation(),
162 ". Your worker job that contains the device (\"", device_name,
163 "\") was probably restarted. Check your "
164 "worker job for the reason why it was restarted.");
165 }
166 if (!consumebuf_status.ok()) {
167 done(consumebuf_status, nullptr);
168 return;
169 }
170
171 Hook* existing_hook = nullptr;
172 do {
173 mutex_lock l(mu_);
174 if (!status_.ok()) {
175 consumebuf_status = status_;
176 break;
177 }
178 auto it = hook_table_.find(key);
179 if (it != hook_table_.end()) {
180 // Prepare to consume immediately.
181 if (it->second->cons_cb) {
182 consumebuf_status =
183 errors::Internal("Second consumer arrived for key ", key);
184 break;
185 }
186 existing_hook = it->second;
187 hook_table_.erase(it);
188 existing_hook->cons_cb = done;
189 } else {
190 // Hang consumer callback on the Hook.
191 CancellationToken cancellation_token = CancellationManager::kInvalidToken;
192 bool already_cancelled = false;
193 if (cancellation_manager != nullptr) {
194 cancellation_token = cancellation_manager->get_cancellation_token();
195 already_cancelled = !cancellation_manager->RegisterCallback(
196 cancellation_token, [this, key]() { CancelHook(key); });
197 }
198 if (already_cancelled) {
199 consumebuf_status = errors::Cancelled(
200 "Operation was cancelled for BufRendezvous key ", key);
201 } else {
202 Hook* h = new Hook(cancellation_manager, cancellation_token);
203 h->cons_cb = done;
204 it = hook_table_.insert(std::make_pair(key, h)).first;
205 return;
206 }
207 }
208 } while (false);
209 if (existing_hook) {
210 DeregisterCancellation(existing_hook);
211 existing_hook->cons_cb(OkStatus(), existing_hook);
212 return;
213 }
214 if (!consumebuf_status.ok()) {
215 done(consumebuf_status, nullptr);
216 return;
217 }
218}
219
220void BufRendezvous::CancelHook(const string& key) {
221 Hook* h = nullptr;
222 {
223 mutex_lock l(mu_);
224 auto it = hook_table_.find(key);
225 if (it == hook_table_.end()) return;
226 h = it->second;
227 hook_table_.erase(it);
228 }
229 if (h != nullptr) {
230 auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key ",
231 key);
232 if (h->prod_cb != nullptr) {
233 h->prod_cb(s);
234 }
235 if (h->cons_cb != nullptr) {
236 h->cons_cb(s, /*Hook=*/nullptr);
237 }
238 delete h;
239 }
240}
241
242/*static*/
243void BufRendezvous::DoneWithHook(Hook* h) {
244 h->prod_cb(OkStatus());
245 delete h;
246}
247
248void BufRendezvous::LogContents() {
249 mutex_lock l(mu_);
250 LOG(INFO) << strings::StrCat("BufRendezvous ",
251 strings::Hex(reinterpret_cast<uint64>(this)),
252 " step_id=", step_id_, " current contents:");
253 for (const auto& it : hook_table_) {
254 LOG(INFO) << it.first << ":" << it.second->DebugString();
255 }
256}
257
258} // namespace tensorflow
259