1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/buf_rendezvous.h" |
16 | |
17 | #include "absl/strings/numbers.h" |
18 | #include "absl/strings/str_cat.h" |
19 | #include "absl/strings/string_view.h" |
20 | #include "tensorflow/core/common_runtime/device.h" |
21 | #include "tensorflow/core/common_runtime/device_mgr.h" |
22 | #include "tensorflow/core/common_runtime/process_util.h" |
23 | #include "tensorflow/core/framework/cancellation.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/core/notification.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace { |
29 | void DeregisterCancellation(BufRendezvous::Hook* h) { |
30 | if (h->cancellation_manager != nullptr) { |
31 | h->cancellation_manager->DeregisterCallback(h->cancellation_token); |
32 | h->cancellation_manager = nullptr; |
33 | h->cancellation_token = CancellationManager::kInvalidToken; |
34 | } |
35 | } |
36 | } // namespace |
37 | |
38 | BufRendezvous::~BufRendezvous() { |
39 | mutex_lock l(mu_); |
40 | if (!hook_table_.empty()) { |
41 | PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous" ), |
42 | &hook_table_); |
43 | } |
44 | } |
45 | |
46 | void BufRendezvous::StartAbort(const Status& s) { |
47 | CHECK(!s.ok()); |
48 | HookTable dummy_table; |
49 | { |
50 | mutex_lock l(mu_); |
51 | // Use a "derived" status as the status for the rendezvous. Derived |
52 | // status messages are ignored when aggregating errors across devices: this |
53 | // allows us to prefer our original status message over any cancellation |
54 | // related errors. |
55 | status_.Update(StatusGroup::MakeDerived(s)); |
56 | hook_table_.swap(dummy_table); |
57 | } |
58 | PurgeTable(s, &dummy_table); |
59 | } |
60 | |
61 | void BufRendezvous::PurgeTable(const Status& s, HookTable* table) { |
62 | for (auto& it : *table) { |
63 | Hook* h = it.second; |
64 | if (h->cancellation_manager != nullptr) { |
65 | h->cancellation_manager->TryDeregisterCallback(h->cancellation_token); |
66 | } |
67 | if (h->cons_cb != nullptr) { |
68 | h->cons_cb(s, nullptr); |
69 | } |
70 | if (h->prod_cb != nullptr) { |
71 | h->prod_cb(s); |
72 | } |
73 | delete h; |
74 | } |
75 | table->clear(); |
76 | } |
77 | |
78 | string BufRendezvous::Hook::DebugString() const { |
79 | return absl::StrCat("[dev:" , (prod_dev ? prod_dev->name() : "none" ), |
80 | ", ctx:" , reinterpret_cast<uint64>(prod_ctx), |
81 | ", val:" , reinterpret_cast<uint64>(prod_value), |
82 | ", pcb:" , reinterpret_cast<uint64>(&prod_cb), |
83 | ", ccb:" , reinterpret_cast<uint64>(&cons_cb), "]" ); |
84 | } |
85 | |
86 | void BufRendezvous::ProvideBuf(const string& key, Device* dev, |
87 | DeviceContext* dev_ctx, const Tensor* v, |
88 | const AllocatorAttributes& attr, |
89 | const ProducerCallback& done, |
90 | CancellationManager* cancellation_manager) { |
91 | Hook* h = nullptr; |
92 | Status providebuf_status; |
93 | do { |
94 | mutex_lock l(mu_); |
95 | if (!status_.ok()) { |
96 | providebuf_status = status_; |
97 | break; |
98 | } else { |
99 | CancellationToken cancellation_token = CancellationManager::kInvalidToken; |
100 | auto it = hook_table_.find(key); |
101 | if (it == hook_table_.end()) { |
102 | if (cancellation_manager != nullptr) { |
103 | cancellation_token = cancellation_manager->get_cancellation_token(); |
104 | } |
105 | h = new Hook(cancellation_manager, cancellation_token); |
106 | it = hook_table_.insert(std::make_pair(key, h)).first; |
107 | } else { |
108 | if (it->second->prod_cb != nullptr) { |
109 | providebuf_status = errors::Internal( |
110 | "BufRendezvous::ProvideBuf already called for key " , key); |
111 | break; |
112 | } |
113 | h = it->second; |
114 | } |
115 | // Populate Hook with all of the prod values. |
116 | h->prod_dev = dev; |
117 | h->prod_ctx = dev_ctx; |
118 | h->prod_value = v; |
119 | h->prod_attr = attr; |
120 | h->prod_cb = done; |
121 | if (h->cons_cb != nullptr) { |
122 | // If consumer is waiting, kick off right away, removing Hook from |
123 | // table. |
124 | hook_table_.erase(it); |
125 | } else { |
126 | if (cancellation_manager != nullptr && |
127 | !cancellation_manager->RegisterCallback( |
128 | cancellation_token, [this, key]() { CancelHook(key); })) { |
129 | // Register cancellation callback with CancellationManager. If it is |
130 | // already cancelled, call done immediately with cancelled status. |
131 | providebuf_status = errors::Cancelled( |
132 | "Operation was cancelled for BufRendezvous key " , key); |
133 | hook_table_.erase(it); |
134 | delete h; |
135 | } |
136 | h = nullptr; |
137 | } |
138 | } |
139 | } while (false); |
140 | if (h) { |
141 | DeregisterCancellation(h); |
142 | h->cons_cb(OkStatus(), h); |
143 | } |
144 | if (!providebuf_status.ok()) { |
145 | done(providebuf_status); |
146 | } |
147 | } |
148 | |
149 | void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, |
150 | const uint64 device_incarnation, |
151 | const ConsumerCallback& done, |
152 | CancellationManager* cancellation_manager) { |
153 | // Check the incarnation in the request matches the current device |
154 | // incarnation of the producer. |
155 | Device* device; |
156 | Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device); |
157 | if (consumebuf_status.ok() && |
158 | device->attributes().incarnation() != device_incarnation) { |
159 | consumebuf_status = errors::FailedPrecondition( |
160 | "RecvBuf expects a different device incarnation: " , device_incarnation, |
161 | " vs. " , device->attributes().incarnation(), |
162 | ". Your worker job that contains the device (\"" , device_name, |
163 | "\") was probably restarted. Check your " |
164 | "worker job for the reason why it was restarted." ); |
165 | } |
166 | if (!consumebuf_status.ok()) { |
167 | done(consumebuf_status, nullptr); |
168 | return; |
169 | } |
170 | |
171 | Hook* existing_hook = nullptr; |
172 | do { |
173 | mutex_lock l(mu_); |
174 | if (!status_.ok()) { |
175 | consumebuf_status = status_; |
176 | break; |
177 | } |
178 | auto it = hook_table_.find(key); |
179 | if (it != hook_table_.end()) { |
180 | // Prepare to consume immediately. |
181 | if (it->second->cons_cb) { |
182 | consumebuf_status = |
183 | errors::Internal("Second consumer arrived for key " , key); |
184 | break; |
185 | } |
186 | existing_hook = it->second; |
187 | hook_table_.erase(it); |
188 | existing_hook->cons_cb = done; |
189 | } else { |
190 | // Hang consumer callback on the Hook. |
191 | CancellationToken cancellation_token = CancellationManager::kInvalidToken; |
192 | bool already_cancelled = false; |
193 | if (cancellation_manager != nullptr) { |
194 | cancellation_token = cancellation_manager->get_cancellation_token(); |
195 | already_cancelled = !cancellation_manager->RegisterCallback( |
196 | cancellation_token, [this, key]() { CancelHook(key); }); |
197 | } |
198 | if (already_cancelled) { |
199 | consumebuf_status = errors::Cancelled( |
200 | "Operation was cancelled for BufRendezvous key " , key); |
201 | } else { |
202 | Hook* h = new Hook(cancellation_manager, cancellation_token); |
203 | h->cons_cb = done; |
204 | it = hook_table_.insert(std::make_pair(key, h)).first; |
205 | return; |
206 | } |
207 | } |
208 | } while (false); |
209 | if (existing_hook) { |
210 | DeregisterCancellation(existing_hook); |
211 | existing_hook->cons_cb(OkStatus(), existing_hook); |
212 | return; |
213 | } |
214 | if (!consumebuf_status.ok()) { |
215 | done(consumebuf_status, nullptr); |
216 | return; |
217 | } |
218 | } |
219 | |
220 | void BufRendezvous::CancelHook(const string& key) { |
221 | Hook* h = nullptr; |
222 | { |
223 | mutex_lock l(mu_); |
224 | auto it = hook_table_.find(key); |
225 | if (it == hook_table_.end()) return; |
226 | h = it->second; |
227 | hook_table_.erase(it); |
228 | } |
229 | if (h != nullptr) { |
230 | auto s = errors::Cancelled("Operation was cancelled for BufRendezvous key " , |
231 | key); |
232 | if (h->prod_cb != nullptr) { |
233 | h->prod_cb(s); |
234 | } |
235 | if (h->cons_cb != nullptr) { |
236 | h->cons_cb(s, /*Hook=*/nullptr); |
237 | } |
238 | delete h; |
239 | } |
240 | } |
241 | |
242 | /*static*/ |
243 | void BufRendezvous::DoneWithHook(Hook* h) { |
244 | h->prod_cb(OkStatus()); |
245 | delete h; |
246 | } |
247 | |
248 | void BufRendezvous::LogContents() { |
249 | mutex_lock l(mu_); |
250 | LOG(INFO) << strings::StrCat("BufRendezvous " , |
251 | strings::Hex(reinterpret_cast<uint64>(this)), |
252 | " step_id=" , step_id_, " current contents:" ); |
253 | for (const auto& it : hook_table_) { |
254 | LOG(INFO) << it.first << ":" << it.second->DebugString(); |
255 | } |
256 | } |
257 | |
258 | } // namespace tensorflow |
259 | |