1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tsl/framework/cancellation.h"
17
18#include <forward_list>
19
20#include "absl/memory/memory.h"
21#include "tensorflow/tsl/platform/errors.h"
22#include "tensorflow/tsl/platform/logging.h"
23#include "tensorflow/tsl/platform/status.h"
24
25namespace tsl {
26
27const CancellationToken CancellationManager::kInvalidToken = -1;
28
29CancellationManager::CancellationManager()
30 : is_cancelling_(false),
31 is_cancelled_(false),
32 next_cancellation_token_(0) {}
33
34CancellationManager::CancellationManager(CancellationManager* parent)
35 : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) {
36 is_cancelled_ = parent->RegisterChild(this);
37}
38
39void CancellationManager::StartCancel() {
40 // An "OK" status will not be logged by a callback registered by
41 // RegisterCallbackWithErrorLogging.
42 StartCancelWithStatus(OkStatus());
43}
44
45void CancellationManager::StartCancelWithStatus(const Status& status) {
46 gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks_to_run;
47 std::forward_list<CancellationManager*> children_to_cancel;
48 Notification* cancelled_notification = nullptr;
49 {
50 mutex_lock l(mu_);
51 if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
52 return;
53 }
54 is_cancelling_ = true;
55 if (state_) {
56 std::swap(state_->callbacks, callbacks_to_run);
57
58 // Remove all children from the list of children.
59 CancellationManager* child = state_->first_child;
60 while (child != nullptr) {
61 children_to_cancel.push_front(child);
62 child->is_removed_from_parent_ = true;
63 child = child->next_sibling_;
64 }
65 state_->first_child = nullptr;
66
67 cancelled_notification = &state_->cancelled_notification;
68 }
69 }
70 // We call these callbacks without holding mu_, so that concurrent
71 // calls to DeregisterCallback, which can happen asynchronously, do
72 // not block. The callbacks remain valid because any concurrent call
73 // to DeregisterCallback will block until the
74 // cancelled_notification_ is notified.
75 for (auto key_and_value : callbacks_to_run) {
76 CallbackConfiguration& config = key_and_value.second;
77 if (!status.ok() && config.log_error) {
78 LOG(WARNING) << "Cancellation callback \"" << config.name
79 << "\" is triggered due to a "
80 << (StatusGroup::IsDerived(status) ? "derived" : "root")
81 << " error: " << status.ToString();
82 }
83 config.callback();
84 }
85 for (CancellationManager* child : children_to_cancel) {
86 child->StartCancelWithStatus(status);
87 }
88 {
89 mutex_lock l(mu_);
90 is_cancelling_ = false;
91 is_cancelled_.store(true, std::memory_order_release);
92 }
93 if (cancelled_notification) {
94 cancelled_notification->Notify();
95 }
96}
97
98bool CancellationManager::RegisterCallback(CancellationToken token,
99 CancelCallback callback) {
100 return RegisterCallbackConfig(
101 token, CallbackConfiguration{callback, "", false});
102}
103
104bool CancellationManager::RegisterCallbackWithErrorLogging(
105 CancellationToken token, CancelCallback callback,
106 tsl::StringPiece callback_name) {
107 return RegisterCallbackConfig(
108 token, CallbackConfiguration{callback, std::string(callback_name), true});
109}
110
111bool CancellationManager::RegisterCallbackConfig(CancellationToken token,
112 CallbackConfiguration config) {
113 DCHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
114 mutex_lock l(mu_);
115 bool should_register = !is_cancelled_ && !is_cancelling_;
116 if (should_register) {
117 if (!state_) {
118 state_ = absl::make_unique<State>();
119 }
120 std::swap(state_->callbacks[token], config);
121 }
122 return should_register;
123}
124
125bool CancellationManager::DeregisterCallback(CancellationToken token) {
126 mu_.lock();
127 if (is_cancelled_) {
128 mu_.unlock();
129 return false;
130 } else if (is_cancelling_) {
131 Notification* cancelled_notification =
132 state_ ? &state_->cancelled_notification : nullptr;
133 mu_.unlock();
134 // Wait for all of the cancellation callbacks to be called. This
135 // wait ensures that the caller of DeregisterCallback does not
136 // return immediately and free objects that may be used in the
137 // execution of any currently pending callbacks in StartCancel.
138 if (cancelled_notification) {
139 cancelled_notification->WaitForNotification();
140 }
141 return false;
142 } else {
143 if (state_) {
144 state_->callbacks.erase(token);
145 }
146 mu_.unlock();
147 return true;
148 }
149}
150
151bool CancellationManager::RegisterChild(CancellationManager* child) {
152 mutex_lock l(mu_);
153 if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
154 child->is_removed_from_parent_ = true;
155 return true;
156 }
157
158 if (!state_) {
159 state_ = absl::make_unique<State>();
160 }
161
162 // Push `child` onto the front of the list of children.
163 CancellationManager* current_head = state_->first_child;
164 state_->first_child = child;
165 child->prev_sibling_ = nullptr;
166 child->next_sibling_ = current_head;
167 if (current_head) {
168 current_head->prev_sibling_ = child;
169 }
170
171 return false;
172}
173
174void CancellationManager::DeregisterChild(CancellationManager* child) {
175 DCHECK_EQ(child->parent_, this);
176 Notification* cancelled_notification = nullptr;
177 {
178 mutex_lock l(mu_);
179 if (!child->is_removed_from_parent_) {
180 // Remove the child from this manager's list of children.
181 DCHECK(state_);
182
183 if (child->prev_sibling_ == nullptr) {
184 // The child was at the head of the list.
185 DCHECK_EQ(state_->first_child, child);
186 state_->first_child = child->next_sibling_;
187 } else {
188 child->prev_sibling_->next_sibling_ = child->next_sibling_;
189 }
190
191 if (child->next_sibling_ != nullptr) {
192 child->next_sibling_->prev_sibling_ = child->prev_sibling_;
193 }
194
195 child->is_removed_from_parent_ = true;
196 }
197 if (is_cancelling_) {
198 cancelled_notification = &state_->cancelled_notification;
199 }
200 }
201
202 // Wait for an ongoing call to StartCancel() to finish. This wait ensures that
203 // the caller of DeregisterChild does not return immediately and free a child
204 // that may currently be being cancelled by StartCancel().
205 if (cancelled_notification) {
206 cancelled_notification->WaitForNotification();
207 }
208}
209
210bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
211 mutex_lock lock(mu_);
212 if (is_cancelled_ || is_cancelling_) {
213 return false;
214 } else {
215 if (state_) {
216 state_->callbacks.erase(token);
217 }
218 return true;
219 }
220}
221
222CancellationManager::~CancellationManager() {
223 if (parent_) {
224 parent_->DeregisterChild(this);
225 }
226 if (state_) {
227 StartCancel();
228 }
229}
230
231bool CancellationManager::IsCancelling() {
232 mutex_lock lock(mu_);
233 return is_cancelling_;
234}
235
236Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
237 CancelCallback callback,
238 std::function<void()>* deregister_fn) {
239 if (cancellation_manager) {
240 CancellationToken token = cancellation_manager->get_cancellation_token();
241 if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
242 return errors::Cancelled("Operation was cancelled");
243 }
244 *deregister_fn = [cancellation_manager, token]() {
245 cancellation_manager->DeregisterCallback(token);
246 };
247 } else {
248 VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
249 "not be registered.";
250 *deregister_fn = []() {};
251 }
252 return OkStatus();
253}
254
255} // end namespace tsl
256