1 | /* |
2 | * Copyright (c) Facebook, Inc. and its affiliates. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include <algorithm> |
18 | #include <cstdint> |
19 | #include <limits> |
20 | #include <utility> |
21 | |
22 | #include <glog/logging.h> |
23 | |
24 | namespace folly { |
25 | |
26 | namespace detail { |
27 | |
28 | // Internal cancellation state object. |
29 | class CancellationState { |
30 | public: |
31 | FOLLY_NODISCARD static CancellationStateSourcePtr create(); |
32 | |
33 | private: |
34 | // Constructed initially with a CancellationSource reference count of 1. |
35 | CancellationState() noexcept; |
36 | |
37 | ~CancellationState(); |
38 | |
39 | friend struct CancellationStateTokenDeleter; |
40 | friend struct CancellationStateSourceDeleter; |
41 | |
42 | void removeTokenReference() noexcept; |
43 | void removeSourceReference() noexcept; |
44 | |
45 | public: |
46 | FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept; |
47 | |
48 | FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept; |
49 | |
50 | bool tryAddCallback( |
51 | CancellationCallback* callback, |
52 | bool incrementRefCountIfSuccessful) noexcept; |
53 | |
54 | void removeCallback(CancellationCallback* callback) noexcept; |
55 | |
56 | bool isCancellationRequested() const noexcept; |
57 | bool canBeCancelled() const noexcept; |
58 | |
59 | // Request cancellation. |
60 | // Return 'true' if cancellation had already been requested. |
61 | // Return 'false' if this was the first thread to request |
62 | // cancellation. |
63 | bool requestCancellation() noexcept; |
64 | |
65 | private: |
66 | void lock() noexcept; |
67 | void unlock() noexcept; |
68 | void unlockAndIncrementTokenCount() noexcept; |
69 | void unlockAndDecrementTokenCount() noexcept; |
70 | bool tryLockAndCancelUnlessCancelled() noexcept; |
71 | |
72 | template <typename Predicate> |
73 | bool tryLock(Predicate predicate) noexcept; |
74 | |
75 | static bool canBeCancelled(std::uint64_t state) noexcept; |
76 | static bool isCancellationRequested(std::uint64_t state) noexcept; |
77 | static bool isLocked(std::uint64_t state) noexcept; |
78 | |
79 | static constexpr std::uint64_t kCancellationRequestedFlag = 1; |
80 | static constexpr std::uint64_t kLockedFlag = 2; |
81 | static constexpr std::uint64_t kTokenReferenceCountIncrement = 4; |
82 | static constexpr std::uint64_t kSourceReferenceCountIncrement = |
83 | std::uint64_t(1) << 33u; |
84 | static constexpr std::uint64_t kTokenReferenceCountMask = |
85 | (kSourceReferenceCountIncrement - 1u) - |
86 | (kTokenReferenceCountIncrement - 1u); |
87 | static constexpr std::uint64_t kSourceReferenceCountMask = |
88 | std::numeric_limits<std::uint64_t>::max() - |
89 | (kSourceReferenceCountIncrement - 1u); |
90 | |
91 | // Bit 0 - Cancellation Requested |
92 | // Bit 1 - Locked Flag |
93 | // Bits 2-32 - Token reference count (max ~2 billion) |
94 | // Bits 33-63 - Source reference count (max ~2 billion) |
95 | std::atomic<std::uint64_t> state_; |
96 | CancellationCallback* head_; |
97 | std::thread::id signallingThreadId_; |
98 | }; |
99 | |
100 | inline void CancellationStateTokenDeleter::operator()( |
101 | CancellationState* state) noexcept { |
102 | state->removeTokenReference(); |
103 | } |
104 | |
105 | inline void CancellationStateSourceDeleter::operator()( |
106 | CancellationState* state) noexcept { |
107 | state->removeSourceReference(); |
108 | } |
109 | |
110 | } // namespace detail |
111 | |
112 | inline CancellationToken::CancellationToken( |
113 | const CancellationToken& other) noexcept |
114 | : state_() { |
115 | if (other.state_) { |
116 | state_ = other.state_->addTokenReference(); |
117 | } |
118 | } |
119 | |
120 | inline CancellationToken::CancellationToken(CancellationToken&& other) noexcept |
121 | : state_(std::move(other.state_)) {} |
122 | |
123 | inline CancellationToken& CancellationToken::operator=( |
124 | const CancellationToken& other) noexcept { |
125 | if (state_ != other.state_) { |
126 | CancellationToken temp{other}; |
127 | swap(temp); |
128 | } |
129 | return *this; |
130 | } |
131 | |
132 | inline CancellationToken& CancellationToken::operator=( |
133 | CancellationToken&& other) noexcept { |
134 | state_ = std::move(other.state_); |
135 | return *this; |
136 | } |
137 | |
138 | inline bool CancellationToken::isCancellationRequested() const noexcept { |
139 | return state_ != nullptr && state_->isCancellationRequested(); |
140 | } |
141 | |
142 | inline bool CancellationToken::canBeCancelled() const noexcept { |
143 | return state_ != nullptr && state_->canBeCancelled(); |
144 | } |
145 | |
146 | inline void CancellationToken::swap(CancellationToken& other) noexcept { |
147 | std::swap(state_, other.state_); |
148 | } |
149 | |
150 | inline CancellationToken::CancellationToken( |
151 | detail::CancellationStateTokenPtr state) noexcept |
152 | : state_(std::move(state)) {} |
153 | |
154 | inline bool operator==( |
155 | const CancellationToken& a, |
156 | const CancellationToken& b) noexcept { |
157 | return a.state_ == b.state_; |
158 | } |
159 | |
160 | inline bool operator!=( |
161 | const CancellationToken& a, |
162 | const CancellationToken& b) noexcept { |
163 | return !(a == b); |
164 | } |
165 | |
166 | inline CancellationSource::CancellationSource() |
167 | : state_(detail::CancellationState::create()) {} |
168 | |
169 | inline CancellationSource::CancellationSource( |
170 | const CancellationSource& other) noexcept |
171 | : state_() { |
172 | if (other.state_) { |
173 | state_ = other.state_->addSourceReference(); |
174 | } |
175 | } |
176 | |
177 | inline CancellationSource::CancellationSource( |
178 | CancellationSource&& other) noexcept |
179 | : state_(std::move(other.state_)) {} |
180 | |
181 | inline CancellationSource& CancellationSource::operator=( |
182 | const CancellationSource& other) noexcept { |
183 | if (state_ != other.state_) { |
184 | CancellationSource temp{other}; |
185 | swap(temp); |
186 | } |
187 | return *this; |
188 | } |
189 | |
190 | inline CancellationSource& CancellationSource::operator=( |
191 | CancellationSource&& other) noexcept { |
192 | state_ = std::move(other.state_); |
193 | return *this; |
194 | } |
195 | |
196 | inline CancellationSource CancellationSource::invalid() noexcept { |
197 | return CancellationSource{detail::CancellationStateSourcePtr{}}; |
198 | } |
199 | |
200 | inline bool CancellationSource::isCancellationRequested() const noexcept { |
201 | return state_ != nullptr && state_->isCancellationRequested(); |
202 | } |
203 | |
204 | inline bool CancellationSource::canBeCancelled() const noexcept { |
205 | return state_ != nullptr; |
206 | } |
207 | |
208 | inline CancellationToken CancellationSource::getToken() const noexcept { |
209 | if (state_ != nullptr) { |
210 | return CancellationToken{state_->addTokenReference()}; |
211 | } |
212 | return CancellationToken{}; |
213 | } |
214 | |
215 | inline bool CancellationSource::requestCancellation() const noexcept { |
216 | if (state_ != nullptr) { |
217 | return state_->requestCancellation(); |
218 | } |
219 | return false; |
220 | } |
221 | |
222 | inline void CancellationSource::swap(CancellationSource& other) noexcept { |
223 | std::swap(state_, other.state_); |
224 | } |
225 | |
226 | inline CancellationSource::CancellationSource( |
227 | detail::CancellationStateSourcePtr&& state) noexcept |
228 | : state_(std::move(state)) {} |
229 | |
230 | template < |
231 | typename Callable, |
232 | std::enable_if_t< |
233 | std::is_constructible<CancellationCallback::VoidFunction, Callable>:: |
234 | value, |
235 | int>> |
236 | inline CancellationCallback::CancellationCallback( |
237 | CancellationToken&& ct, |
238 | Callable&& callable) |
239 | : next_(nullptr), |
240 | prevNext_(nullptr), |
241 | state_(nullptr), |
242 | callback_(static_cast<Callable&&>(callable)), |
243 | destructorHasRunInsideCallback_(nullptr), |
244 | callbackCompleted_(false) { |
245 | if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, false)) { |
246 | state_ = ct.state_.release(); |
247 | } |
248 | } |
249 | |
250 | template < |
251 | typename Callable, |
252 | std::enable_if_t< |
253 | std::is_constructible<CancellationCallback::VoidFunction, Callable>:: |
254 | value, |
255 | int>> |
256 | inline CancellationCallback::CancellationCallback( |
257 | const CancellationToken& ct, |
258 | Callable&& callable) |
259 | : next_(nullptr), |
260 | prevNext_(nullptr), |
261 | state_(nullptr), |
262 | callback_(static_cast<Callable&&>(callable)), |
263 | destructorHasRunInsideCallback_(nullptr), |
264 | callbackCompleted_(false) { |
265 | if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, true)) { |
266 | state_ = ct.state_.get(); |
267 | } |
268 | } |
269 | |
270 | inline CancellationCallback::~CancellationCallback() { |
271 | if (state_ != nullptr) { |
272 | state_->removeCallback(this); |
273 | } |
274 | } |
275 | |
276 | inline void CancellationCallback::invokeCallback() noexcept { |
277 | // Invoke within a noexcept context so that we std::terminate() if it throws. |
278 | callback_(); |
279 | } |
280 | |
281 | namespace detail { |
282 | |
283 | inline CancellationStateSourcePtr CancellationState::create() { |
284 | return CancellationStateSourcePtr{new CancellationState()}; |
285 | } |
286 | |
287 | inline CancellationState::CancellationState() noexcept |
288 | : state_(kSourceReferenceCountIncrement), |
289 | head_(nullptr), |
290 | signallingThreadId_() {} |
291 | |
292 | inline CancellationStateTokenPtr |
293 | CancellationState::addTokenReference() noexcept { |
294 | state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed); |
295 | return CancellationStateTokenPtr{this}; |
296 | } |
297 | |
298 | inline void CancellationState::removeTokenReference() noexcept { |
299 | const auto oldState = state_.fetch_sub( |
300 | kTokenReferenceCountIncrement, std::memory_order_acq_rel); |
301 | DCHECK( |
302 | (oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement); |
303 | if (oldState < (2 * kTokenReferenceCountIncrement)) { |
304 | delete this; |
305 | } |
306 | } |
307 | |
308 | inline CancellationStateSourcePtr |
309 | CancellationState::addSourceReference() noexcept { |
310 | state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed); |
311 | return CancellationStateSourcePtr{this}; |
312 | } |
313 | |
314 | inline void CancellationState::removeSourceReference() noexcept { |
315 | const auto oldState = state_.fetch_sub( |
316 | kSourceReferenceCountIncrement, std::memory_order_acq_rel); |
317 | DCHECK( |
318 | (oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement); |
319 | if (oldState < |
320 | (kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) { |
321 | delete this; |
322 | } |
323 | } |
324 | |
325 | inline bool CancellationState::isCancellationRequested() const noexcept { |
326 | return isCancellationRequested(state_.load(std::memory_order_acquire)); |
327 | } |
328 | |
329 | inline bool CancellationState::canBeCancelled() const noexcept { |
330 | return canBeCancelled(state_.load(std::memory_order_acquire)); |
331 | } |
332 | |
333 | inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept { |
334 | // Can be cancelled if there is at least one CancellationSource ref-count |
335 | // or if cancellation has been requested. |
336 | return (state >= kSourceReferenceCountIncrement) || |
337 | isCancellationRequested(state); |
338 | } |
339 | |
340 | inline bool CancellationState::isCancellationRequested( |
341 | std::uint64_t state) noexcept { |
342 | return (state & kCancellationRequestedFlag) != 0; |
343 | } |
344 | |
345 | inline bool CancellationState::isLocked(std::uint64_t state) noexcept { |
346 | return (state & kLockedFlag) != 0; |
347 | } |
348 | |
349 | } // namespace detail |
350 | |
351 | } // namespace folly |
352 | |