1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_ |
17 | #define TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_ |
18 | |
#include <atomic>
#include <functional>
#include <memory>
#include <string>

#include "tensorflow/tsl/lib/gtl/flatmap.h"
#include "tensorflow/tsl/platform/hash.h"
#include "tensorflow/tsl/platform/mutex.h"
#include "tensorflow/tsl/platform/notification.h"
#include "tensorflow/tsl/platform/status.h"
#include "tensorflow/tsl/platform/stringpiece.h"
#include "tensorflow/tsl/platform/thread_annotations.h"
#include "tensorflow/tsl/platform/types.h"
30 | |
31 | namespace tsl { |
32 | |
// Opaque handle identifying one CancelCallback registration on a
// CancellationManager; pass the same value to RegisterCallback and later to
// DeregisterCallback.
//
// Obtain values only from CancellationManager::get_cancellation_token();
// do not fabricate them.
using CancellationToken = int64_t;
39 | |
// Nullary function run by a CancellationManager when a step is cancelled.
//
// NOTE(mrry): Implementations have important restrictions; read the comment
// on CancellationManager::RegisterCallback before writing one.
using CancelCallback = std::function<void()>;
45 | |
46 | // This class should never simultaneously be used as the cancellation manager |
47 | // for two separate sets of executions (i.e two separate steps, or two separate |
48 | // function executions). |
49 | class CancellationManager { |
50 | public: |
51 | // A value that won't be returned by get_cancellation_token(). |
52 | static const CancellationToken kInvalidToken; |
53 | |
54 | CancellationManager(); |
55 | |
56 | // Constructs a new CancellationManager that is a "child" of `*parent`. |
57 | // |
58 | // If `*parent` is cancelled, `*this` will be cancelled. `*parent` must |
59 | // outlive the created CancellationManager. |
60 | explicit CancellationManager(CancellationManager* parent); |
61 | |
62 | ~CancellationManager(); |
63 | |
64 | // Run all callbacks associated with this manager. |
65 | void StartCancel(); |
66 | |
67 | // Run all callbacks associated with this manager with a status. |
68 | // Currently the status is for logging purpose only. See also |
69 | // CancellationManager::RegisterCallbackWithErrorLogging. |
70 | void StartCancelWithStatus(const Status& status); |
71 | |
72 | // Returns true iff StartCancel() has been called. |
73 | bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } |
74 | |
75 | // Returns a token that must be used in calls to RegisterCallback |
76 | // and DeregisterCallback. |
77 | CancellationToken get_cancellation_token() { |
78 | return next_cancellation_token_.fetch_add(1); |
79 | } |
80 | |
81 | // Attempts to register the given callback to be invoked when this |
82 | // manager is cancelled. Returns true if the callback was |
83 | // registered; returns false if this manager was already cancelled, |
84 | // and the callback was not registered. |
85 | // |
86 | // If this method returns false, it is the caller's responsibility |
87 | // to perform any cancellation cleanup. |
88 | // |
89 | // This method is tricky to use correctly. The following usage pattern |
90 | // is recommended: |
91 | // |
92 | // class ObjectWithCancellableOperation { |
93 | // mutex mu_; |
94 | // void CancellableOperation(CancellationManager* cm, |
95 | // std::function<void(Status)> callback) { |
96 | // bool already_cancelled; |
97 | // CancellationToken token = cm->get_cancellation_token(); |
98 | // { |
99 | // mutex_lock(mu_); |
100 | // already_cancelled = !cm->RegisterCallback( |
101 | // [this, token]() { Cancel(token); }); |
102 | // if (!already_cancelled) { |
103 | // // Issue asynchronous operation. Associate the pending operation |
104 | // // with `token` in some object state, or provide another way for |
105 | // // the Cancel method to look up the operation for cancellation. |
106 | // // Ensure that `cm->DeregisterCallback(token)` is called without |
107 | // // holding `mu_`, before `callback` is invoked. |
108 | // // ... |
109 | // } |
110 | // } |
111 | // if (already_cancelled) { |
112 | // callback(errors::Cancelled("Operation was cancelled")); |
113 | // } |
114 | // } |
115 | // |
116 | // void Cancel(CancellationToken token) { |
117 | // mutex_lock(mu_); |
118 | // // Take action to cancel the operation with the given cancellation |
119 | // // token. |
120 | // } |
121 | // |
122 | // NOTE(mrry): The caller should take care that (i) the calling code |
123 | // is robust to `callback` being invoked asynchronously (e.g. from |
124 | // another thread), (ii) `callback` is deregistered by a call to |
125 | // this->DeregisterCallback(token) when the operation completes |
126 | // successfully, and (iii) `callback` does not invoke any method |
127 | // on this cancellation manager. Furthermore, it is important that |
128 | // the eventual caller of the complementary DeregisterCallback does not |
129 | // hold any mutexes that are required by `callback`. |
130 | bool RegisterCallback(CancellationToken token, CancelCallback callback); |
131 | |
132 | // Similar to RegisterCallback, but if the cancellation manager starts a |
133 | // cancellation with an error status, it will log the error status before |
134 | // invoking the callback. `callback_name` is a human-readable name of the |
135 | // callback, which will be displayed on the log. |
136 | bool RegisterCallbackWithErrorLogging(CancellationToken token, |
137 | CancelCallback callback, |
138 | tsl::StringPiece callback_name); |
139 | |
140 | // Deregister the callback that, when registered, was associated |
141 | // with the given cancellation token. Returns true iff the callback |
142 | // was deregistered and will not be invoked; otherwise returns false |
143 | // after the callback has been invoked, blocking if necessary. |
144 | // |
145 | // NOTE(mrry): This method may block if cancellation is in progress. |
146 | // The caller of this method must not hold any mutexes that are required |
147 | // to invoke any cancellation callback that has been registered with this |
148 | // cancellation manager. |
149 | bool DeregisterCallback(CancellationToken token); |
150 | |
151 | // Deregister the callback that, when registered, was associated |
152 | // with the given cancellation token. Returns true iff the callback |
153 | // was deregistered and will not be invoked; otherwise returns false |
154 | // immediately, with no guarantee that the callback has completed. |
155 | // |
156 | // This method is guaranteed to return true if StartCancel has not been |
157 | // called. |
158 | bool TryDeregisterCallback(CancellationToken token); |
159 | |
160 | // Returns true iff cancellation is in progress. |
161 | bool IsCancelling(); |
162 | |
163 | private: |
164 | struct CallbackConfiguration { |
165 | CancelCallback callback; |
166 | std::string name; |
167 | bool log_error = false; |
168 | }; |
169 | |
170 | struct State { |
171 | Notification cancelled_notification; |
172 | gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks; |
173 | |
174 | // If this CancellationManager has any children, this member points to the |
175 | // head of a doubly-linked list of its children. |
176 | CancellationManager* first_child = nullptr; // Not owned. |
177 | }; |
178 | |
179 | bool RegisterCallbackConfig(CancellationToken token, |
180 | CallbackConfiguration config); |
181 | |
182 | bool RegisterChild(CancellationManager* child); |
183 | void DeregisterChild(CancellationManager* child); |
184 | |
185 | bool is_cancelling_; |
186 | std::atomic_bool is_cancelled_; |
187 | std::atomic<CancellationToken> next_cancellation_token_; |
188 | |
189 | CancellationManager* const parent_ = nullptr; // Not owned. |
190 | |
191 | // If this CancellationManager is associated with a parent, this member will |
192 | // be set to `true` after this is removed from the parent's list of children. |
193 | bool is_removed_from_parent_ TF_GUARDED_BY(parent_->mu_) = false; |
194 | |
195 | // If this CancellationManager is associated with a parent, these members form |
196 | // a doubly-linked list of that parent's children. |
197 | // |
198 | // These fields are valid only when `this->is_removed_from_parent_` is false. |
199 | CancellationManager* prev_sibling_ TF_GUARDED_BY(parent_->mu_) = |
200 | nullptr; // Not owned. |
201 | CancellationManager* next_sibling_ TF_GUARDED_BY(parent_->mu_) = |
202 | nullptr; // Not owned. |
203 | |
204 | mutex mu_; |
205 | std::unique_ptr<State> state_ TF_GUARDED_BY(mu_); |
206 | }; |
207 | |
208 | // Registers the given cancellation callback, returning a function that can be |
209 | // used to deregister the callback. If `cancellation_manager` is NULL, no |
210 | // registration occurs and `deregister_fn` will be a no-op. |
211 | Status RegisterCancellationCallback(CancellationManager* cancellation_manager, |
212 | std::function<void()> callback, |
213 | std::function<void()>* deregister_fn); |
214 | |
215 | } // namespace tsl |
216 | |
217 | #endif // TENSORFLOW_TSL_FRAMEWORK_CANCELLATION_H_ |
218 | |