1/*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <algorithm>
18#include <cstdint>
19#include <limits>
20#include <utility>
21
22#include <glog/logging.h>
23
24namespace folly {
25
26namespace detail {
27
// Internal cancellation state object.
//
// A single instance is shared by the CancellationSource, CancellationToken
// and registered CancellationCallback objects that refer to it. Its lifetime
// is governed by two reference counts (source and token) packed into state_
// together with two flag bits. The object deletes itself when the last
// reference of either kind is removed.
class CancellationState {
 public:
  // Heap-allocates a new state and returns an owning pointer that holds its
  // initial CancellationSource reference.
  FOLLY_NODISCARD static CancellationStateSourcePtr create();

 private:
  // Constructed initially with a CancellationSource reference count of 1.
  CancellationState() noexcept;

  ~CancellationState();

  // The smart-pointer deleters drop references through the private
  // remove*Reference() functions below.
  friend struct CancellationStateTokenDeleter;
  friend struct CancellationStateSourceDeleter;

  // Drop one token / source reference. Deletes `this` if that was the last
  // remaining reference of either kind.
  void removeTokenReference() noexcept;
  void removeSourceReference() noexcept;

 public:
  // Add one token / source reference and return an owning pointer that
  // releases it on destruction. Every caller in this file already holds a
  // reference, which is what makes the relaxed increment safe.
  FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept;

  FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept;

  // Attempt to register `callback`. Returns true if it was registered.
  // NOTE(review): when incrementRefCountIfSuccessful is true, a token
  // reference is presumably added on success (the
  // CancellationCallback(const CancellationToken&) constructor relies on
  // this). When false, the caller transfers its own reference. Confirm
  // against the out-of-line definition.
  bool tryAddCallback(
      CancellationCallback* callback,
      bool incrementRefCountIfSuccessful) noexcept;

  // Deregister a callback previously added by tryAddCallback().
  void removeCallback(CancellationCallback* callback) noexcept;

  // Queries on an acquire-load snapshot of state_.
  bool isCancellationRequested() const noexcept;
  bool canBeCancelled() const noexcept;

  // Request cancellation.
  // Return 'true' if cancellation had already been requested.
  // Return 'false' if this was the first thread to request
  // cancellation.
  bool requestCancellation() noexcept;

 private:
  // Locking helpers guarding the callback list. They are presumably built on
  // kLockedFlag in state_; the definitions are not shown here.
  void lock() noexcept;
  void unlock() noexcept;
  void unlockAndIncrementTokenCount() noexcept;
  void unlockAndDecrementTokenCount() noexcept;
  bool tryLockAndCancelUnlessCancelled() noexcept;

  template <typename Predicate>
  bool tryLock(Predicate predicate) noexcept;

  // Decode individual fields from a snapshot of state_.
  static bool canBeCancelled(std::uint64_t state) noexcept;
  static bool isCancellationRequested(std::uint64_t state) noexcept;
  static bool isLocked(std::uint64_t state) noexcept;

  static constexpr std::uint64_t kCancellationRequestedFlag = 1;
  static constexpr std::uint64_t kLockedFlag = 2;
  // One token reference adds 1 << 2; one source reference adds 1 << 33.
  static constexpr std::uint64_t kTokenReferenceCountIncrement = 4;
  static constexpr std::uint64_t kSourceReferenceCountIncrement =
      std::uint64_t(1) << 33u;
  // Selects bits 2-32 (the token count).
  static constexpr std::uint64_t kTokenReferenceCountMask =
      (kSourceReferenceCountIncrement - 1u) -
      (kTokenReferenceCountIncrement - 1u);
  // Selects bits 33-63 (the source count).
  static constexpr std::uint64_t kSourceReferenceCountMask =
      std::numeric_limits<std::uint64_t>::max() -
      (kSourceReferenceCountIncrement - 1u);

  // Bit 0 - Cancellation Requested
  // Bit 1 - Locked Flag
  // Bits 2-32 - Token reference count (max ~2 billion)
  // Bits 33-63 - Source reference count (max ~2 billion)
  std::atomic<std::uint64_t> state_;
  // NOTE(review): presumably the head of an intrusive list of registered
  // callbacks (CancellationCallback has next_/prevNext_). Null initially.
  CancellationCallback* head_;
  // NOTE(review): presumably identifies the thread currently running
  // callbacks during requestCancellation(). Default-constructed initially.
  std::thread::id signallingThreadId_;
};
99
100inline void CancellationStateTokenDeleter::operator()(
101 CancellationState* state) noexcept {
102 state->removeTokenReference();
103}
104
105inline void CancellationStateSourceDeleter::operator()(
106 CancellationState* state) noexcept {
107 state->removeSourceReference();
108}
109
110} // namespace detail
111
112inline CancellationToken::CancellationToken(
113 const CancellationToken& other) noexcept
114 : state_() {
115 if (other.state_) {
116 state_ = other.state_->addTokenReference();
117 }
118}
119
120inline CancellationToken::CancellationToken(CancellationToken&& other) noexcept
121 : state_(std::move(other.state_)) {}
122
123inline CancellationToken& CancellationToken::operator=(
124 const CancellationToken& other) noexcept {
125 if (state_ != other.state_) {
126 CancellationToken temp{other};
127 swap(temp);
128 }
129 return *this;
130}
131
132inline CancellationToken& CancellationToken::operator=(
133 CancellationToken&& other) noexcept {
134 state_ = std::move(other.state_);
135 return *this;
136}
137
138inline bool CancellationToken::isCancellationRequested() const noexcept {
139 return state_ != nullptr && state_->isCancellationRequested();
140}
141
142inline bool CancellationToken::canBeCancelled() const noexcept {
143 return state_ != nullptr && state_->canBeCancelled();
144}
145
146inline void CancellationToken::swap(CancellationToken& other) noexcept {
147 std::swap(state_, other.state_);
148}
149
150inline CancellationToken::CancellationToken(
151 detail::CancellationStateTokenPtr state) noexcept
152 : state_(std::move(state)) {}
153
154inline bool operator==(
155 const CancellationToken& a,
156 const CancellationToken& b) noexcept {
157 return a.state_ == b.state_;
158}
159
160inline bool operator!=(
161 const CancellationToken& a,
162 const CancellationToken& b) noexcept {
163 return !(a == b);
164}
165
166inline CancellationSource::CancellationSource()
167 : state_(detail::CancellationState::create()) {}
168
169inline CancellationSource::CancellationSource(
170 const CancellationSource& other) noexcept
171 : state_() {
172 if (other.state_) {
173 state_ = other.state_->addSourceReference();
174 }
175}
176
177inline CancellationSource::CancellationSource(
178 CancellationSource&& other) noexcept
179 : state_(std::move(other.state_)) {}
180
181inline CancellationSource& CancellationSource::operator=(
182 const CancellationSource& other) noexcept {
183 if (state_ != other.state_) {
184 CancellationSource temp{other};
185 swap(temp);
186 }
187 return *this;
188}
189
190inline CancellationSource& CancellationSource::operator=(
191 CancellationSource&& other) noexcept {
192 state_ = std::move(other.state_);
193 return *this;
194}
195
196inline CancellationSource CancellationSource::invalid() noexcept {
197 return CancellationSource{detail::CancellationStateSourcePtr{}};
198}
199
200inline bool CancellationSource::isCancellationRequested() const noexcept {
201 return state_ != nullptr && state_->isCancellationRequested();
202}
203
204inline bool CancellationSource::canBeCancelled() const noexcept {
205 return state_ != nullptr;
206}
207
208inline CancellationToken CancellationSource::getToken() const noexcept {
209 if (state_ != nullptr) {
210 return CancellationToken{state_->addTokenReference()};
211 }
212 return CancellationToken{};
213}
214
215inline bool CancellationSource::requestCancellation() const noexcept {
216 if (state_ != nullptr) {
217 return state_->requestCancellation();
218 }
219 return false;
220}
221
222inline void CancellationSource::swap(CancellationSource& other) noexcept {
223 std::swap(state_, other.state_);
224}
225
226inline CancellationSource::CancellationSource(
227 detail::CancellationStateSourcePtr&& state) noexcept
228 : state_(std::move(state)) {}
229
230template <
231 typename Callable,
232 std::enable_if_t<
233 std::is_constructible<CancellationCallback::VoidFunction, Callable>::
234 value,
235 int>>
236inline CancellationCallback::CancellationCallback(
237 CancellationToken&& ct,
238 Callable&& callable)
239 : next_(nullptr),
240 prevNext_(nullptr),
241 state_(nullptr),
242 callback_(static_cast<Callable&&>(callable)),
243 destructorHasRunInsideCallback_(nullptr),
244 callbackCompleted_(false) {
245 if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, false)) {
246 state_ = ct.state_.release();
247 }
248}
249
250template <
251 typename Callable,
252 std::enable_if_t<
253 std::is_constructible<CancellationCallback::VoidFunction, Callable>::
254 value,
255 int>>
256inline CancellationCallback::CancellationCallback(
257 const CancellationToken& ct,
258 Callable&& callable)
259 : next_(nullptr),
260 prevNext_(nullptr),
261 state_(nullptr),
262 callback_(static_cast<Callable&&>(callable)),
263 destructorHasRunInsideCallback_(nullptr),
264 callbackCompleted_(false) {
265 if (ct.state_ != nullptr && ct.state_->tryAddCallback(this, true)) {
266 state_ = ct.state_.get();
267 }
268}
269
270inline CancellationCallback::~CancellationCallback() {
271 if (state_ != nullptr) {
272 state_->removeCallback(this);
273 }
274}
275
276inline void CancellationCallback::invokeCallback() noexcept {
277 // Invoke within a noexcept context so that we std::terminate() if it throws.
278 callback_();
279}
280
281namespace detail {
282
283inline CancellationStateSourcePtr CancellationState::create() {
284 return CancellationStateSourcePtr{new CancellationState()};
285}
286
287inline CancellationState::CancellationState() noexcept
288 : state_(kSourceReferenceCountIncrement),
289 head_(nullptr),
290 signallingThreadId_() {}
291
292inline CancellationStateTokenPtr
293CancellationState::addTokenReference() noexcept {
294 state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed);
295 return CancellationStateTokenPtr{this};
296}
297
// Drops one token reference and deletes this object if no references of any
// kind remain.
inline void CancellationState::removeTokenReference() noexcept {
  // acq_rel: the release half publishes this holder's prior accesses, and the
  // acquire half makes other holders' accesses visible to whichever thread
  // ends up performing the delete.
  const auto oldState = state_.fetch_sub(
      kTokenReferenceCountIncrement, std::memory_order_acq_rel);
  DCHECK(
      (oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement);
  // Only the flag bits (0-1) lie below kTokenReferenceCountIncrement. So
  // oldState < 2 * increment means the token count was exactly 1 and the
  // source count (bits 33-63) was 0: this was the last reference.
  if (oldState < (2 * kTokenReferenceCountIncrement)) {
    delete this;
  }
}
307
308inline CancellationStateSourcePtr
309CancellationState::addSourceReference() noexcept {
310 state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed);
311 return CancellationStateSourcePtr{this};
312}
313
// Drops one source reference and deletes this object if no references of any
// kind remain.
inline void CancellationState::removeSourceReference() noexcept {
  // acq_rel for the same reason as in removeTokenReference(): whichever
  // thread deletes must observe all other holders' prior accesses.
  const auto oldState = state_.fetch_sub(
      kSourceReferenceCountIncrement, std::memory_order_acq_rel);
  DCHECK(
      (oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement);
  // The DCHECK establishes that the source count was >= 1. Then
  // oldState < source + token increment means the source count was exactly 1
  // and the token count (bits 2-32) was 0, with only flag bits possibly set.
  // This was the last reference.
  if (oldState <
      (kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) {
    delete this;
  }
}
324
325inline bool CancellationState::isCancellationRequested() const noexcept {
326 return isCancellationRequested(state_.load(std::memory_order_acquire));
327}
328
329inline bool CancellationState::canBeCancelled() const noexcept {
330 return canBeCancelled(state_.load(std::memory_order_acquire));
331}
332
333inline bool CancellationState::canBeCancelled(std::uint64_t state) noexcept {
334 // Can be cancelled if there is at least one CancellationSource ref-count
335 // or if cancellation has been requested.
336 return (state >= kSourceReferenceCountIncrement) ||
337 isCancellationRequested(state);
338}
339
340inline bool CancellationState::isCancellationRequested(
341 std::uint64_t state) noexcept {
342 return (state & kCancellationRequestedFlag) != 0;
343}
344
345inline bool CancellationState::isLocked(std::uint64_t state) noexcept {
346 return (state & kLockedFlag) != 0;
347}
348
349} // namespace detail
350
351} // namespace folly
352