1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_
17#define TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_
18
#include <atomic>
#include <functional>
#include <memory>
#include <string>

#include "tensorflow/tsl/lib/gtl/flatmap.h"
#include "tensorflow/tsl/platform/hash.h"
#include "tensorflow/tsl/platform/mutex.h"
#include "tensorflow/tsl/platform/notification.h"
#include "tensorflow/tsl/platform/status.h"
#include "tensorflow/tsl/platform/stringpiece.h"
#include "tensorflow/tsl/platform/thread_annotations.h"
#include "tensorflow/tsl/platform/types.h"
30
31namespace tsl {
32
// Handle identifying one CancelCallback registration on a CancellationManager.
//
// Obtain values only from CancellationManager::get_cancellation_token(), then
// pass the same value to RegisterCallback() and later to DeregisterCallback().
using CancellationToken = int64_t;
39
// Argument-less callable that a CancellationManager runs when a step is
// cancelled.
//
// NOTE(mrry): See caveats about CancelCallback implementations in the
// comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;
45
46// This class should never simultaneously be used as the cancellation manager
47// for two separate sets of executions (i.e two separate steps, or two separate
48// function executions).
49class CancellationManager {
50 public:
51 // A value that won't be returned by get_cancellation_token().
52 static const CancellationToken kInvalidToken;
53
54 CancellationManager();
55
56 // Constructs a new CancellationManager that is a "child" of `*parent`.
57 //
58 // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must
59 // outlive the created CancellationManager.
60 explicit CancellationManager(CancellationManager* parent);
61
62 ~CancellationManager();
63
64 // Run all callbacks associated with this manager.
65 void StartCancel();
66
67 // Run all callbacks associated with this manager with a status.
68 // Currently the status is for logging purpose only. See also
69 // CancellationManager::RegisterCallbackWithErrorLogging.
70 void StartCancelWithStatus(const Status& status);
71
72 // Returns true iff StartCancel() has been called.
73 bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
74
75 // Returns a token that must be used in calls to RegisterCallback
76 // and DeregisterCallback.
77 CancellationToken get_cancellation_token() {
78 return next_cancellation_token_.fetch_add(1);
79 }
80
81 // Attempts to register the given callback to be invoked when this
82 // manager is cancelled. Returns true if the callback was
83 // registered; returns false if this manager was already cancelled,
84 // and the callback was not registered.
85 //
86 // If this method returns false, it is the caller's responsibility
87 // to perform any cancellation cleanup.
88 //
89 // This method is tricky to use correctly. The following usage pattern
90 // is recommended:
91 //
92 // class ObjectWithCancellableOperation {
93 // mutex mu_;
94 // void CancellableOperation(CancellationManager* cm,
95 // std::function<void(Status)> callback) {
96 // bool already_cancelled;
97 // CancellationToken token = cm->get_cancellation_token();
98 // {
99 // mutex_lock(mu_);
100 // already_cancelled = !cm->RegisterCallback(
101 // [this, token]() { Cancel(token); });
102 // if (!already_cancelled) {
103 // // Issue asynchronous operation. Associate the pending operation
104 // // with `token` in some object state, or provide another way for
105 // // the Cancel method to look up the operation for cancellation.
106 // // Ensure that `cm->DeregisterCallback(token)` is called without
107 // // holding `mu_`, before `callback` is invoked.
108 // // ...
109 // }
110 // }
111 // if (already_cancelled) {
112 // callback(errors::Cancelled("Operation was cancelled"));
113 // }
114 // }
115 //
116 // void Cancel(CancellationToken token) {
117 // mutex_lock(mu_);
118 // // Take action to cancel the operation with the given cancellation
119 // // token.
120 // }
121 //
122 // NOTE(mrry): The caller should take care that (i) the calling code
123 // is robust to `callback` being invoked asynchronously (e.g. from
124 // another thread), (ii) `callback` is deregistered by a call to
125 // this->DeregisterCallback(token) when the operation completes
126 // successfully, and (iii) `callback` does not invoke any method
127 // on this cancellation manager. Furthermore, it is important that
128 // the eventual caller of the complementary DeregisterCallback does not
129 // hold any mutexes that are required by `callback`.
130 bool RegisterCallback(CancellationToken token, CancelCallback callback);
131
132 // Similar to RegisterCallback, but if the cancellation manager starts a
133 // cancellation with an error status, it will log the error status before
134 // invoking the callback. `callback_name` is a human-readable name of the
135 // callback, which will be displayed on the log.
136 bool RegisterCallbackWithErrorLogging(CancellationToken token,
137 CancelCallback callback,
138 tsl::StringPiece callback_name);
139
140 // Deregister the callback that, when registered, was associated
141 // with the given cancellation token. Returns true iff the callback
142 // was deregistered and will not be invoked; otherwise returns false
143 // after the callback has been invoked, blocking if necessary.
144 //
145 // NOTE(mrry): This method may block if cancellation is in progress.
146 // The caller of this method must not hold any mutexes that are required
147 // to invoke any cancellation callback that has been registered with this
148 // cancellation manager.
149 bool DeregisterCallback(CancellationToken token);
150
151 // Deregister the callback that, when registered, was associated
152 // with the given cancellation token. Returns true iff the callback
153 // was deregistered and will not be invoked; otherwise returns false
154 // immediately, with no guarantee that the callback has completed.
155 //
156 // This method is guaranteed to return true if StartCancel has not been
157 // called.
158 bool TryDeregisterCallback(CancellationToken token);
159
160 // Returns true iff cancellation is in progress.
161 bool IsCancelling();
162
163 private:
164 struct CallbackConfiguration {
165 CancelCallback callback;
166 std::string name;
167 bool log_error = false;
168 };
169
170 struct State {
171 Notification cancelled_notification;
172 gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks;
173
174 // If this CancellationManager has any children, this member points to the
175 // head of a doubly-linked list of its children.
176 CancellationManager* first_child = nullptr; // Not owned.
177 };
178
179 bool RegisterCallbackConfig(CancellationToken token,
180 CallbackConfiguration config);
181
182 bool RegisterChild(CancellationManager* child);
183 void DeregisterChild(CancellationManager* child);
184
185 bool is_cancelling_;
186 std::atomic_bool is_cancelled_;
187 std::atomic<CancellationToken> next_cancellation_token_;
188
189 CancellationManager* const parent_ = nullptr; // Not owned.
190
191 // If this CancellationManager is associated with a parent, this member will
192 // be set to `true` after this is removed from the parent's list of children.
193 bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false;
194
195 // If this CancellationManager is associated with a parent, these members form
196 // a doubly-linked list of that parent's children.
197 //
198 // These fields are valid only when `this->is_removed_from_parent_` is false.
199 CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) =
200 nullptr; // Not owned.
201 CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) =
202 nullptr; // Not owned.
203
204 mutex mu_;
205 std::unique_ptr<State> state_ TF_GUARDED_BY(mu_);
206};
207
208// Registers the given cancellation callback, returning a function that can be
209// used to deregister the callback. If `cancellation_manager` is NULL, no
210// registration occurs and `deregister_fn` will be a no-op.
211Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
212 std::function<void()> callback,
213 std::function<void()>* deregister_fn);
214
215} // namespace tsl
216
217#endif // TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_
218