#pragma once

#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/MaybeOwned.h>
#include <atomic>
#include <climits>
#include <memory>
#include <stdexcept>

namespace pybind11 {
template <typename, typename...>
class class_;
}

namespace c10 {
class intrusive_ptr_target;
namespace raw {
namespace weak_intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}
namespace intrusive_ptr {
inline void incref(intrusive_ptr_target* self);
}

// constructor tag used by intrusive_ptr constructors
struct DontIncreaseRefcount {};
} // namespace raw
/**
 * intrusive_ptr<T> is an alternative to shared_ptr<T> that has better
 * performance because it does the refcounting intrusively
 * (i.e. in a member of the object itself).
 * Your class T needs to inherit from intrusive_ptr_target to allow it to be
 * used in an intrusive_ptr<T>. Your class's constructor should not allow
 * `this` to escape to other threads or create an intrusive_ptr from `this`.
 */
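// Example (illustrative sketch, not part of this header): a hypothetical
// class `MyObject` opts into intrusive refcounting by inheriting from
// intrusive_ptr_target and is created via make_intrusive (defined below).
//
//   struct MyObject : c10::intrusive_ptr_target {
//     int payload;
//     explicit MyObject(int p) : payload(p) {}
//   };
//
//   c10::intrusive_ptr<MyObject> obj = c10::make_intrusive<MyObject>(42);
//   c10::intrusive_ptr<MyObject> alias = obj; // refcount becomes 2
//   // when both pointers are destroyed, the refcount hits 0 and the
//   // MyObject is deleted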

// Note [Stack allocated intrusive_ptr_target safety]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// A well known problem with std::enable_shared_from_this is that it
// allows you to create a std::shared_ptr from a stack allocated object,
// which is totally bogus because the object will die once you return
// from the stack. In intrusive_ptr, we can detect that this has occurred,
// because we set the refcount/weakcount of objects which inherit from
// intrusive_ptr_target to zero, *unless* we can prove that the object
// was dynamically allocated (e.g., via make_intrusive).
//
// Thus, whenever you transmute a T* into an intrusive_ptr<T>, we check
// and make sure that the refcount isn't zero (or, a more subtle
// test for weak_intrusive_ptr<T>, for which the refcount may validly
// be zero, but the weak refcount better not be zero), because that
// tells us if the object was allocated by us. If it wasn't, no
// intrusive_ptr for you!
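//
// Illustrative example (assuming the hypothetical MyObject from the sketch
// above): the check described here would fire in a debug build for a stack
// allocated target.
//
//   MyObject on_stack(1);
//   // refcount_ and weakcount_ are still 0, so the following would
//   // assertion-fail in debug builds instead of silently producing a
//   // dangling pointer:
//   // auto p = c10::intrusive_ptr<MyObject>::unsafe_reclaim_from_nonowning(
//   //     &on_stack);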

class C10_API intrusive_ptr_target {
  // Note [Weak references for intrusive refcounting]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Here's the scheme:
  //
  // - refcount == number of strong references to the object
  //   weakcount == number of weak references to the object,
  //     plus one more if refcount > 0
  //   An invariant: refcount > 0 => weakcount > 0
  //
  // - c10::StorageImpl stays live as long as there are any strong
  //   or weak pointers to it (weakcount > 0, since strong
  //   references count as a +1 to weakcount)
  //
  // - finalizers are called and data_ptr is deallocated when refcount == 0
  //
  // - Once refcount == 0, it can never again be > 0 (the transition
  //   from > 0 to == 0 is monotonic)
  //
  // - When you access c10::StorageImpl via a weak pointer, you must
  //   atomically increment the use count, if it is greater than 0.
  //   If it is not, you must report that the storage is dead.
  //
  mutable std::atomic<size_t> refcount_;
  mutable std::atomic<size_t> weakcount_;

  template <typename T, typename NullType>
  friend class intrusive_ptr;
  friend inline void raw::intrusive_ptr::incref(intrusive_ptr_target* self);

  template <typename T, typename NullType>
  friend class weak_intrusive_ptr;
  friend inline void raw::weak_intrusive_ptr::incref(
      intrusive_ptr_target* self);

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;

 protected:
  // protected destructor. We never want to destruct intrusive_ptr_target*
  // directly.
  virtual ~intrusive_ptr_target() {
// Disable -Wterminate and -Wexceptions so we're allowed to use assertions
// (i.e. throw exceptions) in a destructor.
// We also have to disable -Wunknown-warning-option and -Wpragmas, because
// some other compilers don't know about -Wterminate or -Wexceptions and
// will show a warning about unknown warning options otherwise.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning( \
    disable : 4297) // function assumed not to throw an exception but does
#else
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wterminate"
#pragma GCC diagnostic ignored "-Wexceptions"
#endif
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Second condition is there to accommodate
        // unsafe_adapt_non_heap_allocated: since we are doing our own
        // deallocation in that case, it is correct for each
        // expected_decref to have happened (some user code tried to
        // decref and thus free the object, but it didn't happen right
        // away) or not (no user code tried to free the object, and
        // now it's getting destroyed through whatever mechanism the
        // caller of unsafe_adapt_non_heap_allocated wanted to
        // use). We choose our reference count such that the count
        // will not dip below INT_MAX regardless.
        refcount_.load() == 0 || refcount_.load() >= INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has intrusive_ptr to it; refcount was ",
        refcount_.load());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // See ~intrusive_ptr for optimization that will frequently result in 1
        // at destruction time.
        weakcount_.load() == 1 || weakcount_.load() == 0 ||
            weakcount_.load() == INT_MAX - 1 || weakcount_.load() == INT_MAX,
        "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#else
#pragma GCC diagnostic pop
#endif
  }

  constexpr intrusive_ptr_target() noexcept : refcount_(0), weakcount_(0) {}

  // intrusive_ptr_target supports copy and move: but refcount and weakcount
  // don't participate (since they are intrinsic properties of the memory
  // location)
  intrusive_ptr_target(intrusive_ptr_target&& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(intrusive_ptr_target&& /*other*/) noexcept {
    return *this;
  }

  intrusive_ptr_target(const intrusive_ptr_target& /*other*/) noexcept
      : intrusive_ptr_target() {}

  intrusive_ptr_target& operator=(
      const intrusive_ptr_target& /*other*/) noexcept {
    return *this;
  }

 private:
  /**
   * This is called when refcount reaches zero.
   * You can override this to release expensive resources.
   * There might still be weak references, so your object might not get
   * destructed yet, but you can assume the object isn't used anymore,
   * i.e. no more calls to methods or accesses to members (we just can't
   * destruct it yet because we need the weakcount accessible).
   *
   * If there are no weak references (i.e. your class is about to be
   * destructed), this function WILL NOT be called.
   */
  virtual void release_resources() {}
};
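
// Example (illustrative sketch): a hypothetical subclass that releases an
// expensive resource as soon as the last strong reference goes away, even if
// weak references are still keeping the control data alive.
//
//   struct CachedBlob : c10::intrusive_ptr_target {
//     std::vector<char> buffer;
//     void release_resources() override {
//       // Called when the refcount hits 0 while weak references still exist;
//       // the destructor itself only runs once weakcount also drops to 0.
//       buffer.clear();
//       buffer.shrink_to_fit();
//     }
//   };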

namespace detail {
template <class TTarget>
struct intrusive_target_default_null_type final {
  static constexpr TTarget* singleton() noexcept {
    return nullptr;
  }
};

template <class TTarget, class ToNullType, class FromNullType>
TTarget* assign_ptr_(TTarget* rhs) {
  if (FromNullType::singleton() == rhs) {
    return ToNullType::singleton();
  } else {
    return rhs;
  }
}

// Increment needs to be acquire-release to make use_count() and
// unique() reliable.
inline size_t atomic_refcount_increment(std::atomic<size_t>& refcount) {
  return refcount.fetch_add(1, std::memory_order_acq_rel) + 1;
}

// weak_use_count() is only used for testing, so we don't need it to
// be reliable. Relaxed should be fine.
inline size_t atomic_weakcount_increment(std::atomic<size_t>& weakcount) {
  return weakcount.fetch_add(1, std::memory_order_relaxed) + 1;
}

// Both decrements need to be acquire-release for correctness. See
// e.g. std::shared_ptr implementation.
inline size_t atomic_refcount_decrement(std::atomic<size_t>& refcount) {
  return refcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}

inline size_t atomic_weakcount_decrement(std::atomic<size_t>& weakcount) {
  return weakcount.fetch_sub(1, std::memory_order_acq_rel) - 1;
}

} // namespace detail

template <class TTarget, class NullType>
class weak_intrusive_ptr;

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class intrusive_ptr final {
 private:
// the following static assert would be nice to have but it requires
// the target class T to be fully defined when intrusive_ptr<T> is instantiated
// this is a problem for classes that contain pointers to themselves
// static_assert(
//     std::is_base_of<intrusive_ptr_target, TTarget>::value,
//     "intrusive_ptr can only be used for classes that inherit from
//     intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //  error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return an element_type* pointer");

  TTarget* target_;

  template <typename T>
  friend struct ExclusivelyOwnedTensorTraits;
  template <class TTarget2, class NullType2>
  friend class intrusive_ptr;
  friend class weak_intrusive_ptr<TTarget, NullType>;

  // Make pybind11::class_ a friend class of intrusive_ptr, so that the custom
  // smart holder in pybind11 can access the private constructor
  // intrusive_ptr(T*), which takes ownership of the object. This is required
  // by the custom holder macro PYBIND11_DECLARE_HOLDER_TYPE, which uses
  // intrusive_ptr(TTarget*) to initialize and take ownership of the object.
  // For details, see
  // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
  template <typename, typename...>
  friend class pybind11::class_;

  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_refcount =
          detail::atomic_refcount_increment(target_->refcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_refcount != 1,
          "intrusive_ptr: Cannot increase refcount after it reached zero.");
    }
  }

  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_refcount_decrement(target_->refcount_) == 0) {
      // See comment above about weakcount. As long as refcount>0,
      // weakcount is one larger than the actual number of weak references.
      // So we need to decrement it here.
      bool should_delete =
          target_->weakcount_.load(std::memory_order_acquire) == 1;
      if (!should_delete) {
        // justification for const_cast: release_resources is basically a
        // destructor and a destructor always mutates the object, even for
        // const objects. NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
        const_cast<std::remove_const_t<TTarget>*>(target_)->release_resources();
        should_delete =
            detail::atomic_weakcount_decrement(target_->weakcount_) == 0;
      }
      if (should_delete) {
        delete target_;
      }
    }
  }

  // raw pointer constructors are not public because we shouldn't make
  // intrusive_ptr out of raw pointers except from inside the make_intrusive(),
  // reclaim() and weak_intrusive_ptr::lock() implementations.
  // This constructor will increase the ref counter for you.
  // This constructor is used by make_intrusive() and also by pybind11,
  // which wraps the intrusive_ptr holder around the raw pointer and increfs
  // correspondingly (pybind11 requires the raw pointer constructor to incref
  // by default).
  explicit intrusive_ptr(TTarget* target)
      : intrusive_ptr(target, raw::DontIncreaseRefcount{}) {
    if (target_ != NullType::singleton()) {
      // We just created target_, so we know no other thread has
      // access to it, so we know we needn't care about memory ordering.
      // (On x86_64, a store with memory_order_relaxed generates a plain old
      // `mov`, whereas an atomic increment does a lock-prefixed `add`, which is
      // much more expensive: https://godbolt.org/z/eKPzj8.)
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          target_->refcount_ == 0 && target_->weakcount_ == 0,
          "intrusive_ptr: Newly-created target had non-zero refcounts. Does its "
          "constructor do something strange like incref or create an "
          "intrusive_ptr from `this`?");
      target_->refcount_.store(1, std::memory_order_relaxed);
      target_->weakcount_.store(1, std::memory_order_relaxed);
    }
  }

 public:
  using element_type = TTarget;

  intrusive_ptr() noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  intrusive_ptr(std::nullptr_t) noexcept
      : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {}

  // This constructor will not increase the ref counter for you.
  // We use the tagged dispatch mechanism to explicitly mark this constructor
  // to not increase the refcount
  explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept
      : target_(target) {}

  explicit intrusive_ptr(std::unique_ptr<TTarget> rhs) noexcept
      : intrusive_ptr(rhs.release()) {}

  intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  intrusive_ptr(const intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ intrusive_ptr(const intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~intrusive_ptr() noexcept {
    reset_();
  }

  intrusive_ptr& operator=(intrusive_ptr&& rhs) & noexcept {
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr move assignment got pointer of wrong type.");
    intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  intrusive_ptr& operator=(const intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  template <class From, class FromNullType>
  intrusive_ptr& operator=(const intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. intrusive_ptr copy assignment got pointer of wrong type.");
    intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  TTarget* get() const noexcept {
    return target_;
  }

  TTarget& operator*() const noexcept {
    return *target_;
  }

  TTarget* operator->() const noexcept {
    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
    return target_;
  }

  operator bool() const noexcept {
    return target_ != NullType::singleton();
  }

  void reset() noexcept {
    reset_();
    target_ = NullType::singleton();
  }

  void swap(intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // We do a lot of null-pointer checks in our code, good to have this be cheap.
  bool defined() const noexcept {
    return target_ != NullType::singleton();
  }

  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(std::memory_order_acquire);
  }

  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool unique() const noexcept {
    return use_count() == 1;
  }

  /**
   * Returns an owning (!) pointer to the underlying object and makes the
   * intrusive_ptr instance invalid. That means the refcount is not decreased.
   * You *must* put the returned pointer back into an intrusive_ptr using
   * intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes
   * over ownership. That means the refcount is not increased.
   * This is the counter-part to intrusive_ptr::release() and the pointer
   * passed in *must* have been created using intrusive_ptr::release().
   */
  static intrusive_ptr reclaim(TTarget* owning_ptr) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_ptr == NullType::singleton() ||
            owning_ptr->refcount_.load() == 0 || owning_ptr->weakcount_.load(),
        "TTarget violates the invariant that refcount > 0 => weakcount > 0");
    return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{});
  }
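
  // Example (illustrative): passing ownership across a C API boundary by
  // pairing release() with reclaim(). The entry points and `MyObject` are
  // hypothetical, not part of this header.
  //
  //   extern "C" void* myobject_create() {
  //     // refcount stays at 1; ownership is transferred to the raw pointer
  //     return c10::make_intrusive<MyObject>(42).release();
  //   }
  //   extern "C" void myobject_destroy(void* handle) {
  //     // reclaim() re-adopts the pointer without increfing; when the
  //     // returned intrusive_ptr goes out of scope, the refcount drops to 0
  //     // and the object is destroyed.
  //     c10::intrusive_ptr<MyObject>::reclaim(static_cast<MyObject*>(handle));
  //   }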

  /**
   * Takes an owning pointer to TTarget* and creates an intrusive_ptr
   * representing a new reference, i.e. the raw pointer retains
   * ownership.
   */
  static intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  /**
   * Allocate a heap object with args and wrap it inside an intrusive_ptr and
   * incref. This is a helper function to let make_intrusive() access private
   * intrusive_ptr constructors.
   */
  template <class... Args>
  static intrusive_ptr make(Args&&... args) {
    return intrusive_ptr(new TTarget(std::forward<Args>(args)...));
  }

  /**
   * Turn a new instance of TTarget (e.g., literally allocated
   * using new TTarget(...)) into an intrusive_ptr. If possible,
   * use intrusive_ptr::make instead, which statically guarantees
   * that the allocation was done properly.
   *
   * At the moment, the only reason this method exists is because
   * pybind11 holder types expect to be able to allocate in
   * this way (because pybind11 handles the new allocation itself).
   */
  static intrusive_ptr unsafe_steal_from_new(TTarget* raw_ptr) {
    return intrusive_ptr(raw_ptr);
  }

  /**
   * Turn an instance of TTarget that should not be reference counted
   * (e.g., allocated into an arena with placement new) into an
   * intrusive_ptr. This is gratuitously unsafe and should only be
   * used if you can guarantee that the pointer will not escape and be
   * refcounted as normal.
   *
   * `expected_decrefs` is a debugging parameter: it indicates the
   * number of strong owners the intrusive_ptr_target in question is
   * expected to get. In most use cases, this will likely be 1.
   *
   * The reason this method exists is for manually sharing
   * StorageImpls across Tensors in the static runtime. It needs
   * access to private intrusive_ptr members so that the refcounts can
   * be initialized to custom values.
   */
  static intrusive_ptr unsafe_adapt_non_heap_allocated(
      TTarget* raw_ptr,
      size_t expected_decrefs) {
    intrusive_ptr result(raw_ptr, raw::DontIncreaseRefcount{});
    // INT_MAX is impractically huge for a reference count, while
    // being in no danger of overflowing size_t. We actually only need to
    // initialize the refcount to 2 -- we are just doing an unbalanced
    // incref to prevent the non-heap-allocated target from being
    // freed, and we are optimizing that incref by directly
    // initializing the refcounts rather than doing an expensive
    // atomic increment. The reason to use INT_MAX is to accommodate
    // the debug assertions in ~intrusive_ptr_target.
#ifdef NDEBUG
    expected_decrefs = 0;
#endif
    result.target_->refcount_.store(
        INT_MAX + expected_decrefs, std::memory_order_relaxed);
    result.target_->weakcount_.store(INT_MAX, std::memory_order_relaxed);
    return result;
  }

  /**
   * Turn a **non-owning raw pointer** into an intrusive_ptr. It is
   * the moral equivalent of enable_shared_from_this on a shared pointer.
   *
   * This method is only valid for objects that are already live. If
   * you are looking for the moral equivalent of the unique_ptr<T>(T*)
   * constructor, see unsafe_steal_from_new.
   *
   * TODO: https://github.com/pytorch/pytorch/issues/56482
   */
  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
    ptr.retain_();
    return ptr;
  }
};

template <
    class TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>,
    class... Args>
inline intrusive_ptr<TTarget, NullType> make_intrusive(Args&&... args) {
  return intrusive_ptr<TTarget, NullType>::make(std::forward<Args>(args)...);
}

template <class TTarget, class NullType>
inline void swap(
    intrusive_ptr<TTarget, NullType>& lhs,
    intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}

// To allow intrusive_ptr inside std::map or std::set, we need operator<
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.get() < rhs.get();
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.get() == rhs.get();
}

template <class TTarget1, class NullType1>
inline bool operator==(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    std::nullptr_t) noexcept {
  return lhs.get() == nullptr;
}

template <class TTarget2, class NullType2>
inline bool operator==(
    std::nullptr_t,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return nullptr == rhs.get();
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator!=(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(lhs, rhs);
}

template <class TTarget1, class NullType1>
inline bool operator!=(
    const intrusive_ptr<TTarget1, NullType1>& lhs,
    std::nullptr_t) noexcept {
  return !operator==(lhs, nullptr);
}

template <class TTarget2, class NullType2>
inline bool operator!=(
    std::nullptr_t,
    const intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(nullptr, rhs);
}
template <typename T>
struct MaybeOwnedTraits<c10::intrusive_ptr<T>> {
  using owned_type = c10::intrusive_ptr<T>;
  using borrow_type = c10::intrusive_ptr<T>;

  static borrow_type createBorrow(const owned_type& from) {
    return borrow_type::reclaim(from.get());
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.release();
    lhs = borrow_type::reclaim(rhs.get());
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.release();
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <
    typename TTarget,
    class NullType = detail::intrusive_target_default_null_type<TTarget>>
class weak_intrusive_ptr final {
 private:
  static_assert(
      std::is_base_of<intrusive_ptr_target, TTarget>::value,
      "intrusive_ptr can only be used for classes that inherit from intrusive_ptr_target.");
#ifndef _WIN32
  // This static_assert triggers on MSVC
  //  error C2131: expression did not evaluate to a constant
  static_assert(
      NullType::singleton() == NullType::singleton(),
      "NullType must have a constexpr singleton() method");
#endif
  static_assert(
      std::is_base_of<
          TTarget,
          typename std::remove_pointer<decltype(NullType::singleton())>::type>::
          value,
      "NullType::singleton() must return an element_type* pointer");

  TTarget* target_;

  template <class TTarget2, class NullType2>
  friend class weak_intrusive_ptr;

  void retain_() {
    if (target_ != NullType::singleton()) {
      size_t new_weakcount =
          detail::atomic_weakcount_increment(target_->weakcount_);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          new_weakcount != 1,
          "weak_intrusive_ptr: Cannot increase weakcount after it reached zero.");
    }
  }

  void reset_() noexcept {
    if (target_ != NullType::singleton() &&
        detail::atomic_weakcount_decrement(target_->weakcount_) == 0) {
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDelete)
      delete target_;
    }
    target_ = NullType::singleton();
  }

  constexpr explicit weak_intrusive_ptr(TTarget* target) : target_(target) {}

 public:
  using element_type = TTarget;

  explicit weak_intrusive_ptr(const intrusive_ptr<TTarget, NullType>& ptr)
      : weak_intrusive_ptr(ptr.get()) {
    retain_();
  }

  weak_intrusive_ptr(weak_intrusive_ptr&& rhs) noexcept : target_(rhs.target_) {
    rhs.target_ = NullType::singleton();
  }

  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      weak_intrusive_ptr<From, FromNullType>&& rhs) noexcept
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move constructor got pointer of wrong type.");
    rhs.target_ = FromNullType::singleton();
  }

  weak_intrusive_ptr(const weak_intrusive_ptr& rhs) : target_(rhs.target_) {
    retain_();
  }

  template <class From, class FromNullType>
  /* implicit */ weak_intrusive_ptr(
      const weak_intrusive_ptr<From, FromNullType>& rhs)
      : target_(
            detail::assign_ptr_<TTarget, NullType, FromNullType>(rhs.target_)) {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr copy constructor got pointer of wrong type.");
    retain_();
  }

  ~weak_intrusive_ptr() noexcept {
    reset_();
  }

  weak_intrusive_ptr& operator=(weak_intrusive_ptr&& rhs) & noexcept {
    return operator=<TTarget, NullType>(std::move(rhs));
  }

  template <class From, class FromNullType>
  weak_intrusive_ptr& operator=(
      weak_intrusive_ptr<From, FromNullType>&& rhs) & noexcept {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr move assignment got pointer of wrong type.");
    weak_intrusive_ptr tmp = std::move(rhs);
    swap(tmp);
    return *this;
  }

  weak_intrusive_ptr& operator=(const weak_intrusive_ptr& rhs) & noexcept {
    return operator=<TTarget, NullType>(rhs);
  }

  weak_intrusive_ptr& operator=(
      const intrusive_ptr<TTarget, NullType>& rhs) & noexcept {
    weak_intrusive_ptr tmp(rhs);
    swap(tmp);
    return *this;
  }

  template <class From, class FromNullType>
  weak_intrusive_ptr& operator=(
      const weak_intrusive_ptr<From, NullType>& rhs) & {
    static_assert(
        std::is_convertible<From*, TTarget*>::value,
        "Type mismatch. weak_intrusive_ptr copy assignment got pointer of wrong type.");
    weak_intrusive_ptr tmp = rhs;
    swap(tmp);
    return *this;
  }

  void reset() noexcept {
    reset_();
  }

  void swap(weak_intrusive_ptr& rhs) noexcept {
    TTarget* tmp = target_;
    target_ = rhs.target_;
    rhs.target_ = tmp;
  }

  // NB: This should ONLY be used by the std::hash implementation
  // for weak_intrusive_ptr. Another way you could do this is
  // friend std::hash<weak_intrusive_ptr>, but this triggers two
  // bugs:
  //
  //  (1) It triggers an nvcc bug, where std::hash in a friend class
  //      declaration gets preprocessed into hash, which then cannot
  //      actually be found. The error in this case looks like:
  //
  //        error: no template named 'hash'; did you mean 'std::hash'?
  //
  //  (2) On OS X, std::hash is declared as a struct, not a class.
  //      This triggers:
  //
  //        error: class 'hash' was previously declared as a struct
  //        [-Werror,-Wmismatched-tags]
  //
  // Both of these are work-aroundable, but on the whole, I decided
  // it would be simpler and easier to make work if we just expose
  // an unsafe getter for target_
  //
  TTarget* _unsafe_get_target() const noexcept {
    return target_;
  }

  size_t use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->refcount_.load(
        std::memory_order_acquire); // refcount, not weakcount!
  }

  size_t weak_use_count() const noexcept {
    if (target_ == NullType::singleton()) {
      return 0;
    }
    return target_->weakcount_.load(std::memory_order_acquire);
  }

  bool expired() const noexcept {
    return use_count() == 0;
  }

  intrusive_ptr<TTarget, NullType> lock() const noexcept {
    if (expired()) {
      return intrusive_ptr<TTarget, NullType>();
    } else {
      auto refcount = target_->refcount_.load(std::memory_order_seq_cst);
      do {
        if (refcount == 0) {
          // Object already destructed, no strong references left anymore.
          // Return nullptr.
          return intrusive_ptr<TTarget, NullType>();
        }
      } while (
          !target_->refcount_.compare_exchange_weak(refcount, refcount + 1));
      return intrusive_ptr<TTarget, NullType>(
          target_, raw::DontIncreaseRefcount{});
    }
  }

  /**
   * Returns an owning (but still only weakly referenced) pointer to the
   * underlying object and makes the weak_intrusive_ptr instance invalid.
   * That means the weakcount is not decreased.
   * You *must* put the returned pointer back into a weak_intrusive_ptr using
   * weak_intrusive_ptr::reclaim(ptr) to properly destruct it.
   * This is helpful for C APIs.
   */
  TTarget* release() noexcept {
    TTarget* result = target_;
    target_ = NullType::singleton();
    return result;
  }

  /**
   * Takes an owning (but must be weakly referenced) pointer to TTarget* and
   * creates a weak_intrusive_ptr that takes over ownership.
   * This means that the weakcount is not increased.
   * This is the counter-part to weak_intrusive_ptr::release() and the pointer
   * passed in *must* have been created using weak_intrusive_ptr::release().
   */
  static weak_intrusive_ptr reclaim(TTarget* owning_weak_ptr) {
    // See Note [Stack allocated intrusive_ptr_target safety]
    // if refcount > 0, weakcount must be >1 for weak references to exist.
    // see weak counting explanation at top of this file.
    // if refcount == 0, weakcount only must be >0.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        owning_weak_ptr == NullType::singleton() ||
            owning_weak_ptr->weakcount_.load() > 1 ||
            (owning_weak_ptr->refcount_.load() == 0 &&
             owning_weak_ptr->weakcount_.load() > 0),
        "weak_intrusive_ptr: Can only weak_intrusive_ptr::reclaim() owning pointers that were created using weak_intrusive_ptr::release().");
    return weak_intrusive_ptr(owning_weak_ptr);
  }

  /**
   * Takes a pointer to TTarget* (may be weak or strong) and creates a
   * new weak_intrusive_ptr representing a new weak reference, i.e.
   * the raw pointer retains ownership.
   */
  static weak_intrusive_ptr reclaim_copy(TTarget* owning_ptr) {
    auto ret = reclaim(owning_ptr);
    ret.retain_();
    return ret;
  }

  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator<(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
  template <class TTarget1, class NullType1, class TTarget2, class NullType2>
  friend bool operator==(
      const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
      const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept;
};
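
// Example (illustrative): observing an object without keeping it alive, and
// upgrading to a strong reference via lock(). `MyObject` is the hypothetical
// subclass of intrusive_ptr_target from the sketch near the top of this file.
//
//   c10::intrusive_ptr<MyObject> strong = c10::make_intrusive<MyObject>(1);
//   c10::weak_intrusive_ptr<MyObject> weak(strong);
//
//   if (auto locked = weak.lock()) { // atomically increfs if still alive
//     // use *locked
//   }
//
//   strong.reset(); // last strong reference gone; resources are released
//   assert(weak.expired());
//   assert(!weak.lock()); // lock() now returns a null intrusive_ptr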

template <class TTarget, class NullType>
inline void swap(
    weak_intrusive_ptr<TTarget, NullType>& lhs,
    weak_intrusive_ptr<TTarget, NullType>& rhs) noexcept {
  lhs.swap(rhs);
}

// To allow weak_intrusive_ptr inside std::map or std::set, we need operator<
template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator<(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ < rhs.target_;
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator==(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return lhs.target_ == rhs.target_;
}

template <class TTarget1, class NullType1, class TTarget2, class NullType2>
inline bool operator!=(
    const weak_intrusive_ptr<TTarget1, NullType1>& lhs,
    const weak_intrusive_ptr<TTarget2, NullType2>& rhs) noexcept {
  return !operator==(lhs, rhs);
}

// Alias for documentary purposes, to more easily distinguish
// weak raw intrusive pointers from intrusive pointers.
using weak_intrusive_ptr_target = intrusive_ptr_target;

// This namespace provides some methods for working with
// raw pointers that subclass intrusive_ptr_target. They are not provided
// as methods on intrusive_ptr_target, because ideally you would not need these
// methods at all (use smart pointers), but if you are dealing with legacy code
// that still needs to pass around raw pointers, you may find these quite
// useful.
//
// An important usage note: some functions are only valid if you have a
// strong raw pointer to the object, while others are only valid if you
// have a weak raw pointer to the object. ONLY call intrusive_ptr namespace
// functions on strong pointers, and weak_intrusive_ptr namespace functions
// on weak pointers. If you mix it up, you may get an assert failure.
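//
// Example (illustrative): manually managed strong refcounts for a raw pointer
// that crosses a boundary which cannot hold a c10::intrusive_ptr directly.
// `MyObject` is the hypothetical subclass of intrusive_ptr_target used in the
// earlier sketches.
//
//   MyObject* p = c10::make_intrusive<MyObject>(1).release(); // refcount 1
//   c10::raw::intrusive_ptr::incref(p); // refcount 2
//   c10::raw::intrusive_ptr::decref(p); // refcount 1
//   c10::raw::intrusive_ptr::decref(p); // refcount 0, object is destroyed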
namespace raw {

namespace intrusive_ptr {

// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
  if (self) {
    detail::atomic_refcount_increment(self->refcount_);
  }
}

// WARNING: Unlike the reclaim() API, it is NOT valid to pass
// NullType::singleton to this function
inline void decref(intrusive_ptr_target* self) {
  // Let it die
  c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: Caller still has 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::intrusive_ptr class
}

template <typename T>
inline T* make_weak(T* self) {
  // NB: 'this' is a strong pointer, but we return a weak pointer
  auto ptr = c10::intrusive_ptr<T>::reclaim(self);
  c10::weak_intrusive_ptr<T> wptr(ptr);
  ptr.release();
  return wptr.release();
}

inline size_t use_count(intrusive_ptr_target* self) {
  auto ptr = c10::intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = ptr.use_count();
  ptr.release();
  return r;
}

} // namespace intrusive_ptr

namespace weak_intrusive_ptr {

inline void incref(weak_intrusive_ptr_target* self) {
  detail::atomic_weakcount_increment(self->weakcount_);
}

inline void decref(weak_intrusive_ptr_target* self) {
  // Let it die
  c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  // NB: You still "have" the 'self' pointer, but it's now invalid.
  // If you want more safety, use the actual c10::weak_intrusive_ptr class
}

template <typename T>
inline T* lock(T* self) {
  auto wptr = c10::weak_intrusive_ptr<T>::reclaim(self);
  auto ptr = wptr.lock();
  wptr.release();
  return ptr.release();
}

// This gives the STRONG refcount of a WEAK pointer
inline size_t use_count(weak_intrusive_ptr_target* self) {
  auto wptr = c10::weak_intrusive_ptr<intrusive_ptr_target>::reclaim(self);
  auto r = wptr.use_count();
  wptr.release();
  return r;
}

} // namespace weak_intrusive_ptr

} // namespace raw

} // namespace c10

namespace std {
// To allow intrusive_ptr and weak_intrusive_ptr inside std::unordered_map or
// std::unordered_set, we need std::hash
template <class TTarget, class NullType>
struct hash<c10::intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x.get());
  }
};
template <class TTarget, class NullType>
struct hash<c10::weak_intrusive_ptr<TTarget, NullType>> {
  size_t operator()(const c10::weak_intrusive_ptr<TTarget, NullType>& x) const {
    return std::hash<TTarget*>()(x._unsafe_get_target());
  }
};
} // namespace std
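
// Example (illustrative): thanks to the std::hash specializations above,
// intrusive_ptr can be used as a key in unordered containers. Hashing and
// equality are both based on the raw pointer, so keys compare by identity.
// `MyObject` is the hypothetical subclass used in the earlier sketches.
//
//   std::unordered_set<c10::intrusive_ptr<MyObject>> live_objects;
//   live_objects.insert(c10::make_intrusive<MyObject>(1));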