1 | #pragma once |
2 | |
3 | #include <c10/core/TensorImpl.h> |
4 | |
5 | #include <utility> |
6 | |
7 | namespace c10 { |
8 | // Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and |
9 | // at::TensorBase. |
10 | template <typename TensorType> |
11 | struct ExclusivelyOwnedTensorTraits { |
12 | using repr_type = TensorType; |
13 | using pointer_type = TensorType*; |
14 | using const_pointer_type = const TensorType*; |
15 | |
16 | static repr_type nullRepr() { |
17 | return TensorType(); |
18 | } |
19 | |
20 | template <class... Args> |
21 | static repr_type createInPlace(Args&&... args) { |
22 | return TensorType(std::forward<Args>(args)...); |
23 | } |
24 | |
25 | static repr_type moveToRepr(TensorType&& x) { |
26 | return std::move(x); |
27 | } |
28 | |
29 | static void destroyOwned(TensorType& x) { |
30 | TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); |
31 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
32 | toDestroy != nullptr, "Tensor somehow got null TensorImpl?" ); |
33 | // May be 0 because UndefinedTensorImpl doesn't get its refcount |
34 | // incremented. |
35 | const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); |
36 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
37 | toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined), |
38 | "ExclusivelyOwned<Tensor> destroyed with isUndefined " , |
39 | isUndefined, |
40 | " and refcount " , |
41 | toDestroy->refcount_, |
42 | ", expected 1 or, if isUndefined, 0!" ); |
43 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
44 | toDestroy->weakcount_ == 1 || |
45 | (toDestroy->weakcount_ == 0 && |
46 | toDestroy == UndefinedTensorImpl::singleton()), |
47 | "ExclusivelyOwned<Tensor> destroyed with isUndefined " , |
48 | isUndefined, |
49 | " and weakcount " , |
50 | toDestroy->weakcount_, |
51 | ", expected 1 or, if isUndefined, 0!" ); |
52 | if (!isUndefined) { |
53 | #ifndef NDEBUG |
54 | // Needed to pass the debug assertions in ~intrusive_ptr_target. |
55 | toDestroy->refcount_ = 0; |
56 | toDestroy->weakcount_ = 0; |
57 | #endif |
58 | delete toDestroy; |
59 | } |
60 | } |
61 | |
62 | static TensorType take(TensorType& x) { |
63 | return std::move(x); |
64 | } |
65 | |
66 | static pointer_type getImpl(repr_type& x) { |
67 | return &x; |
68 | } |
69 | |
70 | static const_pointer_type getImpl(const repr_type& x) { |
71 | return &x; |
72 | } |
73 | }; |
74 | } // namespace c10 |
75 | |