1 | #pragma once |
2 | |
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/MaybeOwned.h>
#include <atomic>
#include <climits>
#include <functional>
#include <memory>
#include <stdexcept>
11 | |
12 | namespace pybind11 { |
13 | template <typename, typename...> |
14 | class class_; |
15 | } |
16 | |
17 | namespace c10 { |
18 | class intrusive_ptr_target; |
19 | namespace raw { |
20 | namespace weak_intrusive_ptr { |
21 | inline void incref(intrusive_ptr_target* self); |
22 | } |
23 | namespace intrusive_ptr { |
24 | inline void incref(intrusive_ptr_target* self); |
25 | } |
26 | |
27 | // constructor tag used by intrusive_ptr constructors |
28 | struct DontIncreaseRefcount {}; |
29 | } // namespace raw |
30 | /** |
31 | * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better |
32 | * performance because it does the refcounting intrusively |
33 | * (i.e. in a member of the object itself). |
34 | * Your class T needs to inherit from intrusive_ptr_target to allow it to be |
35 | * used in an intrusive_ptr<T>. Your class's constructor should not allow |
 * `this` to escape to other threads or create an intrusive_ptr from `this`.
37 | */ |
38 | |
39 | // Note [Stack allocated intrusive_ptr_target safety] |
40 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
41 | // A well known problem with std::enable_shared_from_this is that it |
42 | // allows you to create a std::shared_ptr from a stack allocated object, |
43 | // which is totally bogus because the object will die once you return |
44 | // from the stack. In intrusive_ptr, we can detect that this has occurred, |
45 | // because we set the refcount/weakcount of objects which inherit from |
46 | // intrusive_ptr_target to zero, *unless* we can prove that the object |
47 | // was dynamically allocated (e.g., via make_intrusive). |
48 | // |
49 | // Thus, whenever you transmute a T* into a intrusive_ptr<T>, we check |
50 | // and make sure that the refcount isn't zero (or, a more subtle |
51 | // test for weak_intrusive_ptr<T>, for which the refcount may validly |
52 | // be zero, but the weak refcount better not be zero), because that |
53 | // tells us if the object was allocated by us. If it wasn't, no |
54 | // intrusive_ptr for you! |
55 | |
class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  // - refcount == number of strong references to the object
  //   weakcount == number of weak references to the object,
  //   plus one more if refcount > 0
  //   An invariant: refcount > 0 => weakcount > 0
  //
  // - c10::StorageImpl stays live as long as there are any strong
  //   or weak pointers to it (weakcount > 0, since strong
  //   references count as a +1 to weakcount)
  //
  // - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  // - Once refcount == 0, it can never again be > 0 (the transition
  //   from > 0 to == 0 is monotonic)
  //
  // - When you access c10::StorageImpl via a weak pointer, you must
  //   atomically increment the use count, if it is greater than 0.
  //   If it is not, you must report that the storage is dead.
  //
  // Strong reference count. Stays 0 for stack-allocated objects; see
  // Note [Stack allocated intrusive_ptr_target safety].
  mutable std::atomic<size_t> refcount_;
  // Weak reference count, plus one while refcount_ > 0 (see scheme above).
  mutable std::atomic<size_t> weakcount_;

  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
    // Disable -Wterminate and -Wexceptions so we're allowed to use assertions
    // (i.e. throw exceptions) in a destructor.
    // We also have to disable -Wunknown-warning-option and -Wpragmas, because
    // some other compilers don't know about -Wterminate or -Wexceptions and
    // will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below INT_MAX regardless.
        refcount_.load() == 0 || refcount_.load() >= INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == INT_MAX - 1 || weakcount_.load() == INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  // Both counts start at zero so that a stack-allocated (or otherwise
  // non-intrusive_ptr-owned) object is detectable: only intrusive_ptr's
  // owning constructor bumps them to 1. See
  // Note [Stack allocated intrusive_ptr_target safety].
  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    // Deliberately leaves refcount_/weakcount_ of *this untouched.
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    // Deliberately leaves refcount_/weakcount_ of *this untouched.
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
175 | |
176 | namespace detail { |
template <class TTarget>
struct intrusive_target_default_null_type final {
  // The default "empty" sentinel for intrusive_ptr<TTarget> is plain nullptr.
  // Must stay constexpr so intrusive_ptr's static_asserts can evaluate it.
  static constexpr TTarget* singleton() noexcept {
    TTarget* null_value = nullptr;
    return null_value;
  }
};
183 | |
template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  // Translate the source pointer's null sentinel into the destination's;
  // any non-sentinel pointer passes through untouched.
  return (rhs == FromNullType::singleton()) ? ToNullType::singleton() : rhs;
}
192 | |
// The strong-count increment uses acquire-release ordering so that
// use_count() and unique() observe a consistent value.
inline size_t atomic_refcount_increment(std::atomic<size_t>& refcount) {
  const size_t prior = refcount.fetch_add(1, std::memory_order_acq_rel);
  return prior + 1;
}
198 | |
// The weak-count increment may be relaxed: weak_use_count() is only
// consulted by tests, so no cross-thread ordering guarantee is needed.
inline size_t atomic_weakcount_increment(std::atomic<size_t>& weakcount) {
  const size_t prior = weakcount.fetch_add(1, std::memory_order_relaxed);
  return prior + 1;
}
204 | |
// Decrements must be acquire-release for correctness (the thread that
// drops the count to zero must see all prior writes before destroying);
// this mirrors typical std::shared_ptr control-block implementations.
inline size_t atomic_refcount_decrement(std::atomic<size_t>& refcount) {
  const size_t prior = refcount.fetch_sub(1, std::memory_order_acq_rel);
  return prior - 1;
}
210 | |
// Acquire-release for the same reason as atomic_refcount_decrement: the
// decrement that reaches zero gates `delete` of the target.
inline size_t atomic_weakcount_decrement(std::atomic<size_t>& weakcount) {
  const size_t prior = weakcount.fetch_sub(1, std::memory_order_acq_rel);
  return prior - 1;
}
214 | |
215 | } // namespace detail |
216 | |
217 | template <class TTarget, class NullType> |
218 | class weak_intrusive_ptr; |
219 | |
template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
  // the following static assert would be nice to have but it requires
  // the target class T to be fully defined when intrusive_ptr<T> is instantiated
  // this is a problem for classes that contain pointers to themselves
  // static_assert(
  //     std::is_base_of<intrusive_ptr_target, TTarget>::value,
  //     "intrusive_ptr can only be used for classes that inherit from
  //     intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //  error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return a element_type* pointer");

  // The managed object; equals NullType::singleton() when this pointer is
  // empty.
  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom
  // smart holder in pybind11 could access the private constructor of
  // intrusive_ptr(T*) which took the ownership of the object. This is required
  // by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where it uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object. For
  // details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  // Bump the strong refcount (no-op when empty). The debug assert catches
  // "resurrection": increfing an object whose refcount already hit zero.
  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  // Drop our strong reference. If it was the last one, release the target's
  // resources and — once the weakcount also reaches zero — delete it.
  // NB: does NOT null out target_; callers that keep this object alive
  // afterwards (e.g. reset()) must reassign target_ themselves.
  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for
        // const objects.
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.

  // This constructor will increase the ref counter for you.
  // This constructor will be used by the make_intrusive(), and also pybind11,
  // which wrap the intrusive_ptr holder around the raw pointer and incref
  // correspondingly (pybind11 requires raw pointer constructor to incref by
  // default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created result.target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
      // much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  // Default- and nullptr-construction both produce an empty pointer.
  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  // Takes over a unique_ptr. The object must be freshly allocated (refcounts
  // still zero) — the delegated owning constructor asserts this and then
  // sets both counts to 1.
  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  // Converting move: steals rhs's target, translating rhs's null sentinel
  // into ours via assign_ptr_.
  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  // Converting copy: shares rhs's target (increfs) after sentinel
  // translation.
  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    // Delegate to the converting move assignment with our own types.
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    // Move-and-swap: tmp takes over rhs; the old target is released when tmp
    // goes out of scope.
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  // NOTE(review): the parameter is intrusive_ptr<From, NullType>, so
  // FromNullType is never deduced here; this overload is only reached via
  // the explicit operator=<TTarget, NullType>(rhs) call above. Looks
  // intentional, but worth confirming upstream.
  template <class From, class FromNullType>
  intrusive_ptr& operator=(const intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    // Copy-and-swap: tmp increfs rhs's target; the old target is released
    // when tmp goes out of scope.
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  // Non-owning access to the managed object (may be the null sentinel).
  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  // Drop the current reference (possibly destroying the target) and become
  // empty.
  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  // Exchange targets with rhs; no refcounts change.
  void swap(intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  // Current strong reference count (0 when empty).
  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  // Current weak reference count (0 when empty). Note this includes the +1
  // held on behalf of the strong count while refcount > 0.
  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into a intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes
   * over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0 => weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside a intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using new TTarget(...) into an intrusive_ptr. If possible,
   * use intrusive_ptr::make instead which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      size_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // INT_MAX is impractically huge for a reference count, while
    // being in no danger of overflowing size_t. We actually only need to
    // initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use INT_MAX is to accommodate
    // the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        INT_MAX + expected_decrefs, std::memory_order_relaxed);
    result.target_->weakcount_.store(INT_MAX, std::memory_order_relaxed);
    return result;
  }

  /**
   * Turn a **non-owning raw pointer** to an intrusive_ptr. It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live. If
   * you are looking for the moral equivalent of unique_ptr<T>(T*)
   * constructor, see steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};
572 | |
573 | template < |
574 | class TTarget, |
575 | class NullType = detail::intrusive_target_default_null_type<TTarget>, |
576 | class... Args> |
577 | inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) { |
578 | return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...); |
579 | } |
580 | |
581 | template <class TTarget, class NullType> |
582 | inline void swap( |
583 | intrusive_ptr<TTarget, NullType>& lhs, |
584 | intrusive_ptr<TTarget, NullType>& rhs) noexcept { |
585 | lhs.swap(rhs); |
586 | } |
587 | |
588 | // To allow intrusive_ptr inside std::map or std::set, we need operator< |
589 | template <class TTarget1, class NullType1, class TTarget2, class NullType2> |
590 | inline bool operator<( |
591 | const intrusive_ptr<TTarget1, NullType1>& lhs, |
592 | const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
593 | return lhs.get() < rhs.get(); |
594 | } |
595 | |
596 | template <class TTarget1, class NullType1, class TTarget2, class NullType2> |
597 | inline bool operator==( |
598 | const intrusive_ptr<TTarget1, NullType1>& lhs, |
599 | const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
600 | return lhs.get() == rhs.get(); |
601 | } |
602 | |
603 | template <class TTarget1, class NullType1> |
604 | inline bool operator==( |
605 | const intrusive_ptr<TTarget1, NullType1>& lhs, |
606 | std::nullptr_t) noexcept { |
607 | return lhs.get() == nullptr; |
608 | } |
609 | |
610 | template <class TTarget2, class NullType2> |
611 | inline bool operator==( |
612 | std::nullptr_t, |
613 | const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
614 | return nullptr == rhs.get(); |
615 | } |
616 | |
617 | template <class TTarget1, class NullType1, class TTarget2, class NullType2> |
618 | inline bool operator!=( |
619 | const intrusive_ptr<TTarget1, NullType1>& lhs, |
620 | const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
621 | return !operator==(lhs, rhs); |
622 | } |
623 | |
624 | template <class TTarget1, class NullType1> |
625 | inline bool operator!=( |
626 | const intrusive_ptr<TTarget1, NullType1>& lhs, |
627 | std::nullptr_t) noexcept { |
628 | return !operator==(lhs, nullptr); |
629 | } |
630 | |
631 | template <class TTarget2, class NullType2> |
632 | inline bool operator!=( |
633 | std::nullptr_t, |
634 | const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
635 | return !operator==(nullptr, rhs); |
636 | } |
template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  // A "borrow" here is an intrusive_ptr that aliases the owner's target
  // WITHOUT holding a reference: reclaim() wraps the raw pointer without
  // increfing. Consequently every path that would run the borrow's
  // destructor must first release() it, so the refcount is never
  // decremented on behalf of a borrow.
  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    // Strip lhs's non-owned pointer first so the assignment below does not
    // decref a target lhs never owned.
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    // Empty the borrow so its destructor is a no-op (no spurious decref).
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    // Safe because borrow_type and owned_type are the same type.
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    // No cheap validity check is possible for a non-owning alias.
    return true;
  }
};
667 | |
668 | template < |
669 | typename TTarget, |
670 | class NullType = detail::intrusive_target_default_null_type<TTarget>> |
671 | class weak_intrusive_ptr final { |
672 | private: |
673 | static_assert( |
674 | std::is_base_of<intrusive_ptr_target, TTarget>::value, |
675 | "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target." ); |
676 | #ifndef _WIN32 |
677 | // This static_assert triggers on MSVC |
678 | // error C2131: expression did not evaluate to a constant |
679 | static_assert( |
680 | NullType::singleton() == NullType::singleton(), |
681 | "NullType must have a constexpr singleton() method" ); |
682 | #endif |
683 | static_assert( |
684 | std::is_base_of< |
685 | TTarget, |
686 | typename std::remove_pointer<decltype(NullType::singleton())>::type>:: |
687 | value, |
688 | "NullType::singleton() must return a element_type* pointer" ); |
689 | |
690 | TTarget* target_; |
691 | |
692 | template <class TTarget2, class NullType2> |
693 | friend class weak_intrusive_ptr; |
694 | |
695 | void retain_() { |
696 | if (target_ != NullType::singleton()) { |
697 | size_t new_weakcount = |
698 | detail::atomic_weakcount_increment(target_->weakcount_); |
699 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
700 | new_weakcount != 1, |
701 | "weak_intrusive_ptr: Cannot increase weakcount after it reached zero." ); |
702 | } |
703 | } |
704 | |
705 | void reset_() noexcept { |
706 | if (target_ != NullType::singleton() && |
707 | detail::atomic_weakcount_decrement(target_->weakcount_) == 0) { |
708 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete) |
709 | delete target_; |
710 | } |
711 | target_ = NullType::singleton(); |
712 | } |
713 | |
714 | constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {} |
715 | |
716 | public: |
717 | using element_type = TTarget; |
718 | |
719 | explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr) |
720 | : weak_intrusive_ptr(ptr.get()) { |
721 | retain_(); |
722 | } |
723 | |
724 | weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { |
725 | rhs.target_ = NullType::singleton(); |
726 | } |
727 | |
728 | template <class From, class FromNullType> |
729 | /* implicit */ weak_intrusive_ptr( |
730 | weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept |
731 | : target_( |
732 | detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) { |
733 | static_assert( |
734 | std::is_convertible<From*, TTarget*>::value, |
735 | "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type." ); |
736 | rhs.target_ = FromNullType::singleton(); |
737 | } |
738 | |
739 | weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) { |
740 | retain_(); |
741 | } |
742 | |
743 | template <class From, class FromNullType> |
744 | /* implicit */ weak_intrusive_ptr( |
745 | const weak_intrusive_ptr<From, FromNullType>& rhs) |
746 | : target_( |
747 | detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) { |
748 | static_assert( |
749 | std::is_convertible<From*, TTarget*>::value, |
750 | "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type." ); |
751 | retain_(); |
752 | } |
753 | |
754 | ~weak_intrusive_ptr() noexcept { |
755 | reset_(); |
756 | } |
757 | |
758 | weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept { |
759 | return operator=<TTarget, NullType>(std::move(rhs)); |
760 | } |
761 | |
762 | template <class From, class FromNullType> |
763 | weak_intrusive_ptr& operator=( |
764 | weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept { |
765 | static_assert( |
766 | std::is_convertible<From*, TTarget*>::value, |
767 | "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type." ); |
768 | weak_intrusive_ptr tmp = std::move(rhs); |
769 | swap(tmp); |
770 | return *this; |
771 | } |
772 | |
773 | weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept { |
774 | return operator=<TTarget, NullType>(rhs); |
775 | } |
776 | |
777 | weak_intrusive_ptr& operator=( |
778 | const intrusive_ptr<TTarget, NullType>& rhs) & noexcept { |
779 | weak_intrusive_ptr tmp(rhs); |
780 | swap(tmp); |
781 | return *this; |
782 | } |
783 | |
784 | template <class From, class FromNullType> |
785 | weak_intrusive_ptr& operator=( |
786 | const weak_intrusive_ptr<From, NullType>& rhs) & { |
787 | static_assert( |
788 | std::is_convertible<From*, TTarget*>::value, |
789 | "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type." ); |
790 | weak_intrusive_ptr tmp = rhs; |
791 | swap(tmp); |
792 | return *this; |
793 | } |
794 | |
  // Releases the observed target (if any) and makes *this null.
  void reset() noexcept {
    reset_();
  }
798 | |
799 | void swap(weak_intrusive_ptr& rhs) noexcept { |
800 | TTarget* tmp = target_; |
801 | target_ = rhs.target_; |
802 | rhs.target_ = tmp; |
803 | } |
804 | |
805 | // NB: This should ONLY be used by the std::hash implementation |
806 | // for weak_intrusive_ptr. Another way you could do this is |
807 | // friend std::hash<weak_intrusive_ptr>, but this triggers two |
808 | // bugs: |
809 | // |
810 | // (1) It triggers an nvcc bug, where std::hash in a friend class |
811 | // declaration gets preprocessed into hash, which then cannot |
812 | // actually be found. The error in this case looks like: |
813 | // |
814 | // error: no template named 'hash'; did you mean 'std::hash'? |
815 | // |
816 | // (2) On OS X, std::hash is declared as a struct, not a class. |
  //       This triggers:
818 | // |
819 | // error: class 'hash' was previously declared as a struct |
820 | // [-Werror,-Wmismatched-tags] |
821 | // |
822 | // Both of these are work-aroundable, but on the whole, I decided |
823 | // it would be simpler and easier to make work if we just expose |
824 | // an unsafe getter for target_ |
825 | // |
  // Returns the raw target pointer without touching any refcounts; the
  // pointer may refer to an already-expired object.
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }
829 | |
830 | size_t use_count() const noexcept { |
831 | if (target_ == NullType::singleton()) { |
832 | return 0; |
833 | } |
834 | return target_->refcount_.load( |
835 | std::memory_order_acquire); // refcount, not weakcount! |
836 | } |
837 | |
838 | size_t weak_use_count() const noexcept { |
839 | if (target_ == NullType::singleton()) { |
840 | return 0; |
841 | } |
842 | return target_->weakcount_.load(std::memory_order_acquire); |
843 | } |
844 | |
  // True if the observed object has no strong references left (already
  // destructed) or if *this is null — use_count() returns 0 in both cases.
  bool expired() const noexcept {
    return use_count() == 0;
  }
848 | |
  // Attempts to promote this weak reference to a strong one. Returns a
  // null intrusive_ptr if the object has already been destructed;
  // otherwise returns an intrusive_ptr sharing ownership of the target.
  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
        // CAS loop: only increment the refcount while it is still nonzero,
        // so we never resurrect an object whose destructor already ran.
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      // We already incremented the refcount above, so construct the strong
      // pointer without a second increment.
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }
866 | |
867 | /** |
868 | * Returns an owning (but still only weakly referenced) pointer to the |
869 | * underlying object and makes the weak_intrusive_ptr instance invalid. |
870 | * That means the weakcount is not decreased. |
871 | * You *must* put the returned pointer back into a weak_intrusive_ptr using |
872 | * weak_intrusive_ptr::reclaim(ptr) to properly destruct it. |
873 | * This is helpful for C APIs. |
874 | */ |
  TTarget* release() noexcept {
    // Hand the weakly-owning pointer to the caller and become null;
    // intentionally does NOT decrement the weakcount (see the doc comment
    // above: the caller must pass the pointer back via reclaim()).
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }
880 | |
881 | /** |
882 | * Takes an owning (but must be weakly referenced) pointer to TTarget* and |
883 | * creates a weak_intrusive_ptr that takes over ownership. |
884 | * This means that the weakcount is not increased. |
885 | * This is the counter-part to weak_intrusive_ptr::release() and the pointer |
886 | * passed in *must* have been created using weak_intrusive_ptr::release(). |
887 | */ |
  static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    // if refcount > 0, weakcount must be >1 for weak references to exist.
    // see weak counting explanation at top of this file.
    // if refcount == 0, weakcount only must be >0.
    // The debug assert below rejects pointers that do not carry a weak
    // reference (e.g. stack-allocated objects, or pointers obtained from
    // intrusive_ptr::release() instead of weak_intrusive_ptr::release()).
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_weak_ptr == NullType::singleton() ||
            owning_weak_ptr->weakcount_.load() > 1 ||
            (owning_weak_ptr->refcount_.load() == 0 &&
             owning_weak_ptr->weakcount_.load() > 0),
        "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release()." );
    return weak_intrusive_ptr(owning_weak_ptr);
  }
901 | |
902 | /** |
903 | * Takes a pointer to TTarget* (may be weak or strong) and creates a |
904 | * new weak_intrusive_ptr representing a new weak reference, i.e. |
905 | * the raw pointer retains ownership. |
906 | */ |
  static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    // reclaim() adopts the caller's reference, so retain_() immediately
    // registers one more weak reference to compensate — the caller's raw
    // pointer stays valid and retains its ownership.
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }
912 | |
  // The comparison operators are friends so they can read target_ directly
  // without going through _unsafe_get_target().
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator<(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator==(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
921 | }; |
922 | |
// ADL-findable swap so that `using std::swap; swap(a, b);` picks this
// overload up for weak_intrusive_ptr; forwards to the member swap.
template <class TTarget, class NullType>
inline void swap(
    weak_intrusive_ptr<TTarget, NullType>& lhs,
    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}
929 | |
// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
// Orders by raw target pointer identity, not by pointee value.
// NOTE(review): raw `<` on pointers into unrelated objects is formally
// unspecified; std::less<> would guarantee a total order — confirm whether
// that matters for the containers used with this.
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}
937 | |
// Pointer-identity equality: two weak pointers compare equal iff they
// observe the same object (or are both null).
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ == rhs.target_;
}
944 | |
945 | template <class TTarget1, class NullType1, class TTarget2, class NullType2> |
946 | inline bool operator!=( |
947 | const weak_intrusive_ptr<TTarget1, NullType1>& lhs, |
948 | const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept { |
949 | return !operator==(lhs, rhs); |
950 | } |
951 | |
952 | // Alias for documentary purposes, to more easily distinguish |
953 | // weak raw intrusive pointers from intrusive pointers. |
954 | using weak_intrusive_ptr_target = intrusive_ptr_target; |
955 | |
956 | // This namespace provides some methods for working with |
957 | // raw pointers that subclass intrusive_ptr_target. They are not provided |
958 | // as methods on intrusive_ptr_target, because ideally you would not need these |
959 | // methods at all (use smart pointers), but if you are dealing with legacy code |
960 | // that still needs to pass around raw pointers, you may find these quite |
961 | // useful. |
962 | // |
963 | // An important usage note: some functions are only valid if you have a |
964 | // strong raw pointer to the object, while others are only valid if you |
965 | // have a weak raw pointer to the object. ONLY call intrusive_ptr namespace |
966 | // functions on strong pointers, and weak_intrusive_ptr namespace functions |
967 | // on weak pointers. If you mix it up, you may get an assert failure. |
968 | namespace raw { |
969 | |
970 | namespace intrusive_ptr { |
971 | |
972 | // WARNING: Unlike the reclaim() API, it is NOT valid to pass |
973 | // NullType::singleton to this function |
inline void incref(intrusive_ptr_target* self) {
  // Bump the strong refcount of an object the caller strongly owns.
  // A literal nullptr is tolerated (no-op) despite the warning above about
  // NullType::singleton.
  if (self) {
    detail::atomic_refcount_increment(self->refcount_);
  }
}
979 | |
980 | // WARNING: Unlike the reclaim() API, it is NOT valid to pass |
981 | // NullType::singleton to this function |
inline void decref(intrusive_ptr_target* self) {
  // reclaim() adopts the caller's strong reference; the temporary
  // intrusive_ptr drops it (possibly destroying the object) at the end of
  // this statement.
  // Let it die
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}
988 | |
// Given a strong raw pointer, produce an owning WEAK raw pointer to the
// same object; the caller keeps its strong pointer.
template <typename T>
inline T* make_weak(T* self) {
  // NB: 'this' is a strong pointer, but we return a weak pointer
  // Adopt the strong reference, derive a weak reference from it, then
  // release both smart pointers: ptr.release() gives the strong reference
  // back to the caller, wptr.release() hands out the new weak reference.
  auto ptr = c10::intrusive_ptr<T>::reclaim(self);
  c10::weak_intrusive_ptr<T> wptr(ptr);
  ptr.release();
  return wptr.release();
}
997 | |
// Reads the strong refcount of an object the caller strongly owns.
inline size_t use_count(intrusive_ptr_target* self) {
  // Temporarily adopt the pointer to query it, then release() so the
  // caller's ownership is untouched.
  auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = ptr.use_count();
  ptr.release();
  return r;
}
1004 | |
1005 | } // namespace intrusive_ptr |
1006 | |
1007 | namespace weak_intrusive_ptr { |
1008 | |
inline void incref(weak_intrusive_ptr_target* self) {
  // Bump the weakcount. Unlike raw::intrusive_ptr::incref, `self` is
  // dereferenced unconditionally here, so it must not be null.
  detail::atomic_weakcount_increment(self->weakcount_);
}
1012 | |
inline void decref(weak_intrusive_ptr_target* self) {
  // reclaim() adopts the caller's weak reference; the temporary
  // weak_intrusive_ptr drops it when it goes out of scope.
  // Let it die
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}
1019 | |
// Given a weak raw pointer, attempt to promote it to a strong raw pointer.
// Returns an owning strong pointer, or null if the object has expired;
// the caller keeps its weak pointer either way.
template <typename T>
inline T* lock(T* self) {
  // Adopt the weak reference, try to promote it, then release both:
  // wptr.release() returns the weak reference to the caller, ptr.release()
  // hands out the new strong reference (or null).
  auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self);
  auto ptr = wptr.lock();
  wptr.release();
  return ptr.release();
}
1027 | |
1028 | // This gives the STRONG refcount of a WEAK pointer |
// This gives the STRONG refcount of a WEAK pointer
inline size_t use_count(weak_intrusive_ptr_target* self) {
  // Temporarily adopt the weak pointer to query it, then release() so the
  // caller's ownership is untouched.
  auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = wptr.use_count();
  wptr.release();
  return r;
}
1035 | |
1036 | } // namespace weak_intrusive_ptr |
1037 | |
1038 | } // namespace raw |
1039 | |
1040 | } // namespace c10 |
1041 | |
1042 | namespace std { |
1043 | // To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or |
1044 | // std::unordered_set, we need std::hash |
template <class TTarget, class NullType>
struct hash<c10::intrusive_ptr<TTarget, NullType>> {
  // Hashes by the raw target pointer (identity, not pointee value).
  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x.get());
  }
};
template <class TTarget, class NullType>
struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
  // Hashes by the raw target pointer via the unsafe getter (see the long
  // comment on _unsafe_get_target for why friend std::hash isn't used).
  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x._unsafe_get_target());
  }
};
1057 | } // namespace std |
1058 | |