1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tsl/framework/cancellation.h" |
17 | |
18 | #include <forward_list> |
19 | |
20 | #include "absl/memory/memory.h" |
21 | #include "tensorflow/tsl/platform/errors.h" |
22 | #include "tensorflow/tsl/platform/logging.h" |
23 | #include "tensorflow/tsl/platform/status.h" |
24 | |
25 | namespace tsl { |
26 | |
// Sentinel token value meaning "no valid cancellation token". Both constructors
// start next_cancellation_token_ at 0, so presumably no token handed out by
// get_cancellation_token() is negative — confirm against the header.
const CancellationToken CancellationManager::kInvalidToken = -1;
28 | |
29 | CancellationManager::CancellationManager() |
30 | : is_cancelling_(false), |
31 | is_cancelled_(false), |
32 | next_cancellation_token_(0) {} |
33 | |
34 | CancellationManager::CancellationManager(CancellationManager* parent) |
35 | : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) { |
36 | is_cancelled_ = parent->RegisterChild(this); |
37 | } |
38 | |
39 | void CancellationManager::StartCancel() { |
40 | // An "OK" status will not be logged by a callback registered by |
41 | // RegisterCallbackWithErrorLogging. |
42 | StartCancelWithStatus(OkStatus()); |
43 | } |
44 | |
45 | void CancellationManager::StartCancelWithStatus(const Status& status) { |
46 | gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks_to_run; |
47 | std::forward_list<CancellationManager*> children_to_cancel; |
48 | Notification* cancelled_notification = nullptr; |
49 | { |
50 | mutex_lock l(mu_); |
51 | if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { |
52 | return; |
53 | } |
54 | is_cancelling_ = true; |
55 | if (state_) { |
56 | std::swap(state_->callbacks, callbacks_to_run); |
57 | |
58 | // Remove all children from the list of children. |
59 | CancellationManager* child = state_->first_child; |
60 | while (child != nullptr) { |
61 | children_to_cancel.push_front(child); |
62 | child->is_removed_from_parent_ = true; |
63 | child = child->next_sibling_; |
64 | } |
65 | state_->first_child = nullptr; |
66 | |
67 | cancelled_notification = &state_->cancelled_notification; |
68 | } |
69 | } |
70 | // We call these callbacks without holding mu_, so that concurrent |
71 | // calls to DeregisterCallback, which can happen asynchronously, do |
72 | // not block. The callbacks remain valid because any concurrent call |
73 | // to DeregisterCallback will block until the |
74 | // cancelled_notification_ is notified. |
75 | for (auto key_and_value : callbacks_to_run) { |
76 | CallbackConfiguration& config = key_and_value.second; |
77 | if (!status.ok() && config.log_error) { |
78 | LOG(WARNING) << "Cancellation callback \"" << config.name |
79 | << "\" is triggered due to a " |
80 | << (StatusGroup::IsDerived(status) ? "derived" : "root" ) |
81 | << " error: " << status.ToString(); |
82 | } |
83 | config.callback(); |
84 | } |
85 | for (CancellationManager* child : children_to_cancel) { |
86 | child->StartCancelWithStatus(status); |
87 | } |
88 | { |
89 | mutex_lock l(mu_); |
90 | is_cancelling_ = false; |
91 | is_cancelled_.store(true, std::memory_order_release); |
92 | } |
93 | if (cancelled_notification) { |
94 | cancelled_notification->Notify(); |
95 | } |
96 | } |
97 | |
98 | bool CancellationManager::RegisterCallback(CancellationToken token, |
99 | CancelCallback callback) { |
100 | return RegisterCallbackConfig( |
101 | token, CallbackConfiguration{callback, "" , false}); |
102 | } |
103 | |
104 | bool CancellationManager::RegisterCallbackWithErrorLogging( |
105 | CancellationToken token, CancelCallback callback, |
106 | tsl::StringPiece callback_name) { |
107 | return RegisterCallbackConfig( |
108 | token, CallbackConfiguration{callback, std::string(callback_name), true}); |
109 | } |
110 | |
111 | bool CancellationManager::RegisterCallbackConfig(CancellationToken token, |
112 | CallbackConfiguration config) { |
113 | DCHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token" ; |
114 | mutex_lock l(mu_); |
115 | bool should_register = !is_cancelled_ && !is_cancelling_; |
116 | if (should_register) { |
117 | if (!state_) { |
118 | state_ = absl::make_unique<State>(); |
119 | } |
120 | std::swap(state_->callbacks[token], config); |
121 | } |
122 | return should_register; |
123 | } |
124 | |
125 | bool CancellationManager::DeregisterCallback(CancellationToken token) { |
126 | mu_.lock(); |
127 | if (is_cancelled_) { |
128 | mu_.unlock(); |
129 | return false; |
130 | } else if (is_cancelling_) { |
131 | Notification* cancelled_notification = |
132 | state_ ? &state_->cancelled_notification : nullptr; |
133 | mu_.unlock(); |
134 | // Wait for all of the cancellation callbacks to be called. This |
135 | // wait ensures that the caller of DeregisterCallback does not |
136 | // return immediately and free objects that may be used in the |
137 | // execution of any currently pending callbacks in StartCancel. |
138 | if (cancelled_notification) { |
139 | cancelled_notification->WaitForNotification(); |
140 | } |
141 | return false; |
142 | } else { |
143 | if (state_) { |
144 | state_->callbacks.erase(token); |
145 | } |
146 | mu_.unlock(); |
147 | return true; |
148 | } |
149 | } |
150 | |
151 | bool CancellationManager::RegisterChild(CancellationManager* child) { |
152 | mutex_lock l(mu_); |
153 | if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { |
154 | child->is_removed_from_parent_ = true; |
155 | return true; |
156 | } |
157 | |
158 | if (!state_) { |
159 | state_ = absl::make_unique<State>(); |
160 | } |
161 | |
162 | // Push `child` onto the front of the list of children. |
163 | CancellationManager* current_head = state_->first_child; |
164 | state_->first_child = child; |
165 | child->prev_sibling_ = nullptr; |
166 | child->next_sibling_ = current_head; |
167 | if (current_head) { |
168 | current_head->prev_sibling_ = child; |
169 | } |
170 | |
171 | return false; |
172 | } |
173 | |
174 | void CancellationManager::DeregisterChild(CancellationManager* child) { |
175 | DCHECK_EQ(child->parent_, this); |
176 | Notification* cancelled_notification = nullptr; |
177 | { |
178 | mutex_lock l(mu_); |
179 | if (!child->is_removed_from_parent_) { |
180 | // Remove the child from this manager's list of children. |
181 | DCHECK(state_); |
182 | |
183 | if (child->prev_sibling_ == nullptr) { |
184 | // The child was at the head of the list. |
185 | DCHECK_EQ(state_->first_child, child); |
186 | state_->first_child = child->next_sibling_; |
187 | } else { |
188 | child->prev_sibling_->next_sibling_ = child->next_sibling_; |
189 | } |
190 | |
191 | if (child->next_sibling_ != nullptr) { |
192 | child->next_sibling_->prev_sibling_ = child->prev_sibling_; |
193 | } |
194 | |
195 | child->is_removed_from_parent_ = true; |
196 | } |
197 | if (is_cancelling_) { |
198 | cancelled_notification = &state_->cancelled_notification; |
199 | } |
200 | } |
201 | |
202 | // Wait for an ongoing call to StartCancel() to finish. This wait ensures that |
203 | // the caller of DeregisterChild does not return immediately and free a child |
204 | // that may currently be being cancelled by StartCancel(). |
205 | if (cancelled_notification) { |
206 | cancelled_notification->WaitForNotification(); |
207 | } |
208 | } |
209 | |
210 | bool CancellationManager::TryDeregisterCallback(CancellationToken token) { |
211 | mutex_lock lock(mu_); |
212 | if (is_cancelled_ || is_cancelling_) { |
213 | return false; |
214 | } else { |
215 | if (state_) { |
216 | state_->callbacks.erase(token); |
217 | } |
218 | return true; |
219 | } |
220 | } |
221 | |
222 | CancellationManager::~CancellationManager() { |
223 | if (parent_) { |
224 | parent_->DeregisterChild(this); |
225 | } |
226 | if (state_) { |
227 | StartCancel(); |
228 | } |
229 | } |
230 | |
231 | bool CancellationManager::IsCancelling() { |
232 | mutex_lock lock(mu_); |
233 | return is_cancelling_; |
234 | } |
235 | |
236 | Status RegisterCancellationCallback(CancellationManager* cancellation_manager, |
237 | CancelCallback callback, |
238 | std::function<void()>* deregister_fn) { |
239 | if (cancellation_manager) { |
240 | CancellationToken token = cancellation_manager->get_cancellation_token(); |
241 | if (!cancellation_manager->RegisterCallback(token, std::move(callback))) { |
242 | return errors::Cancelled("Operation was cancelled" ); |
243 | } |
244 | *deregister_fn = [cancellation_manager, token]() { |
245 | cancellation_manager->DeregisterCallback(token); |
246 | }; |
247 | } else { |
248 | VLOG(1) << "Cancellation manager is not set. Cancellation callback will " |
249 | "not be registered." ; |
250 | *deregister_fn = []() {}; |
251 | } |
252 | return OkStatus(); |
253 | } |
254 | |
255 | } // end namespace tsl |
256 | |