1#pragma once
2
3#include <c10/core/TensorImpl.h>
4
5#include <utility>
6
7namespace c10 {
8// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and
9// at::TensorBase.
10template <typename TensorType>
11struct ExclusivelyOwnedTensorTraits {
12 using repr_type = TensorType;
13 using pointer_type = TensorType*;
14 using const_pointer_type = const TensorType*;
15
16 static repr_type nullRepr() {
17 return TensorType();
18 }
19
20 template <class... Args>
21 static repr_type createInPlace(Args&&... args) {
22 return TensorType(std::forward<Args>(args)...);
23 }
24
25 static repr_type moveToRepr(TensorType&& x) {
26 return std::move(x);
27 }
28
29 static void destroyOwned(TensorType& x) {
30 TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl();
31 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
32 toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
33 // May be 0 because UndefinedTensorImpl doesn't get its refcount
34 // incremented.
35 const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
36 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
37 toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
38 "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
39 isUndefined,
40 " and refcount ",
41 toDestroy->refcount_,
42 ", expected 1 or, if isUndefined, 0!");
43 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
44 toDestroy->weakcount_ == 1 ||
45 (toDestroy->weakcount_ == 0 &&
46 toDestroy == UndefinedTensorImpl::singleton()),
47 "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
48 isUndefined,
49 " and weakcount ",
50 toDestroy->weakcount_,
51 ", expected 1 or, if isUndefined, 0!");
52 if (!isUndefined) {
53#ifndef NDEBUG
54 // Needed to pass the debug assertions in ~intrusive_ptr_target.
55 toDestroy->refcount_ = 0;
56 toDestroy->weakcount_ = 0;
57#endif
58 delete toDestroy;
59 }
60 }
61
62 static TensorType take(TensorType& x) {
63 return std::move(x);
64 }
65
66 static pointer_type getImpl(repr_type& x) {
67 return &x;
68 }
69
70 static const_pointer_type getImpl(const repr_type& x) {
71 return &x;
72 }
73};
74} // namespace c10
75