1#pragma once
2
3#include <c10/core/StorageImpl.h>
4
5namespace c10 {
6
7struct C10_API Storage {
8 public:
9 struct use_byte_size_t {};
10
11 Storage() = default;
12 Storage(c10::intrusive_ptr<StorageImpl> ptr)
13 : storage_impl_(std::move(ptr)) {}
14
15 // Allocates memory buffer using given allocator and creates a storage with it
16 Storage(
17 use_byte_size_t /*use_byte_size*/,
18 SymInt size_bytes,
19 Allocator* allocator = nullptr,
20 bool resizable = false)
21 : storage_impl_(c10::make_intrusive<StorageImpl>(
22 StorageImpl::use_byte_size_t(),
23 std::move(size_bytes),
24 allocator,
25 resizable)) {}
26
27 // Creates storage with pre-allocated memory buffer. Allocator is given for
28 // potential future reallocations, however it can be nullptr if the storage
29 // is non-resizable
30 Storage(
31 use_byte_size_t /*use_byte_size*/,
32 size_t size_bytes,
33 at::DataPtr data_ptr,
34 at::Allocator* allocator = nullptr,
35 bool resizable = false)
36 : storage_impl_(c10::make_intrusive<StorageImpl>(
37 StorageImpl::use_byte_size_t(),
38 size_bytes,
39 std::move(data_ptr),
40 allocator,
41 resizable)) {}
42
43 // Legacy constructor for partially initialized (dtype or memory) storages
44 // that can be temporarily created with Caffe2 APIs. See the note on top of
45 // TensorImpl.h for details.
46 static Storage create_legacy(at::Device device) {
47 auto allocator = GetAllocator(device.type());
48 return Storage(c10::make_intrusive<StorageImpl>(
49 StorageImpl::use_byte_size_t(),
50 0,
51 allocator->allocate(0), // materialize a non-default Device.
52 allocator,
53 true));
54 }
55
56 // Mimic create_legacy, but without requiring a newly-created StorageImpl.
57 void reset_legacy() {
58 TORCH_CHECK(resizable() && allocator());
59 set_nbytes(0);
60 set_data_ptr_noswap(allocator()->allocate(0));
61 }
62
63 template <typename T>
64 T* data() const {
65 return storage_impl_->data<T>();
66 }
67
68 template <typename T>
69 T* unsafe_data() const {
70 return storage_impl_->unsafe_data<T>();
71 }
72
73 // TODO: remove later
74 void set_nbytes(size_t size_bytes) const {
75 storage_impl_.get()->set_nbytes(size_bytes);
76 }
77
78 void set_nbytes(c10::SymInt size_bytes) const {
79 storage_impl_.get()->set_nbytes(std::move(size_bytes));
80 }
81
82 bool resizable() const {
83 return storage_impl_->resizable();
84 }
85
86 size_t nbytes() const {
87 return storage_impl_->nbytes();
88 }
89
90 SymInt sym_nbytes() const {
91 return storage_impl_->sym_nbytes();
92 }
93 // get() use here is to get const-correctness
94
95 void* data() const {
96 return storage_impl_.get()->data();
97 }
98
99 at::DataPtr& data_ptr() {
100 return storage_impl_->data_ptr();
101 }
102
103 const at::DataPtr& data_ptr() const {
104 return storage_impl_->data_ptr();
105 }
106
107 // Returns the previous data_ptr
108 at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const {
109 return storage_impl_.get()->set_data_ptr(std::move(data_ptr));
110 }
111
112 void set_data_ptr_noswap(at::DataPtr&& data_ptr) const {
113 return storage_impl_.get()->set_data_ptr_noswap(std::move(data_ptr));
114 }
115
116 DeviceType device_type() const {
117 return storage_impl_->device_type();
118 }
119
120 at::Allocator* allocator() const {
121 return storage_impl_.get()->allocator();
122 }
123
124 at::Device device() const {
125 return storage_impl_->device();
126 }
127
128 StorageImpl* unsafeReleaseStorageImpl() {
129 return storage_impl_.release();
130 }
131
132 StorageImpl* unsafeGetStorageImpl() const noexcept {
133 return storage_impl_.get();
134 }
135
136 c10::weak_intrusive_ptr<StorageImpl> getWeakStorageImpl() const {
137 return c10::weak_intrusive_ptr<StorageImpl>(storage_impl_);
138 }
139
140 operator bool() const {
141 return storage_impl_;
142 }
143
144 size_t use_count() const {
145 return storage_impl_.use_count();
146 }
147
148 inline bool unique() const {
149 return storage_impl_.unique();
150 }
151
152 bool is_alias_of(const Storage& other) const {
153 return storage_impl_ == other.storage_impl_;
154 }
155
156 void UniqueStorageShareExternalPointer(
157 void* src,
158 size_t capacity,
159 DeleterFnPtr d = nullptr) {
160 if (!storage_impl_.unique()) {
161 TORCH_CHECK(
162 false,
163 "UniqueStorageShareExternalPointer can only be called when use_count == 1");
164 }
165 storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d);
166 }
167
168 void UniqueStorageShareExternalPointer(
169 at::DataPtr&& data_ptr,
170 size_t capacity) {
171 if (!storage_impl_.unique()) {
172 TORCH_CHECK(
173 false,
174 "UniqueStorageShareExternalPointer can only be called when use_count == 1");
175 }
176 storage_impl_->UniqueStorageShareExternalPointer(
177 std::move(data_ptr), capacity);
178 }
179
180 protected:
181 c10::intrusive_ptr<StorageImpl> storage_impl_;
182};
183
184} // namespace c10
185