1 | #pragma once |
2 | |
3 | #include <c10/core/StorageImpl.h> |
4 | |
5 | namespace c10 { |
6 | |
7 | struct C10_API Storage { |
8 | public: |
9 | struct use_byte_size_t {}; |
10 | |
11 | Storage() = default; |
12 | Storage(c10::intrusive_ptr<StorageImpl> ptr) |
13 | : storage_impl_(std::move(ptr)) {} |
14 | |
15 | // Allocates memory buffer using given allocator and creates a storage with it |
16 | Storage( |
17 | use_byte_size_t /*use_byte_size*/, |
18 | SymInt size_bytes, |
19 | Allocator* allocator = nullptr, |
20 | bool resizable = false) |
21 | : storage_impl_(c10::make_intrusive<StorageImpl>( |
22 | StorageImpl::use_byte_size_t(), |
23 | std::move(size_bytes), |
24 | allocator, |
25 | resizable)) {} |
26 | |
27 | // Creates storage with pre-allocated memory buffer. Allocator is given for |
28 | // potential future reallocations, however it can be nullptr if the storage |
29 | // is non-resizable |
30 | Storage( |
31 | use_byte_size_t /*use_byte_size*/, |
32 | size_t size_bytes, |
33 | at::DataPtr data_ptr, |
34 | at::Allocator* allocator = nullptr, |
35 | bool resizable = false) |
36 | : storage_impl_(c10::make_intrusive<StorageImpl>( |
37 | StorageImpl::use_byte_size_t(), |
38 | size_bytes, |
39 | std::move(data_ptr), |
40 | allocator, |
41 | resizable)) {} |
42 | |
43 | // Legacy constructor for partially initialized (dtype or memory) storages |
44 | // that can be temporarily created with Caffe2 APIs. See the note on top of |
45 | // TensorImpl.h for details. |
46 | static Storage create_legacy(at::Device device) { |
47 | auto allocator = GetAllocator(device.type()); |
48 | return Storage(c10::make_intrusive<StorageImpl>( |
49 | StorageImpl::use_byte_size_t(), |
50 | 0, |
51 | allocator->allocate(0), // materialize a non-default Device. |
52 | allocator, |
53 | true)); |
54 | } |
55 | |
56 | // Mimic create_legacy, but without requiring a newly-created StorageImpl. |
57 | void reset_legacy() { |
58 | TORCH_CHECK(resizable() && allocator()); |
59 | set_nbytes(0); |
60 | set_data_ptr_noswap(allocator()->allocate(0)); |
61 | } |
62 | |
63 | template <typename T> |
64 | T* data() const { |
65 | return storage_impl_->data<T>(); |
66 | } |
67 | |
68 | template <typename T> |
69 | T* unsafe_data() const { |
70 | return storage_impl_->unsafe_data<T>(); |
71 | } |
72 | |
73 | // TODO: remove later |
74 | void set_nbytes(size_t size_bytes) const { |
75 | storage_impl_.get()->set_nbytes(size_bytes); |
76 | } |
77 | |
78 | void set_nbytes(c10::SymInt size_bytes) const { |
79 | storage_impl_.get()->set_nbytes(std::move(size_bytes)); |
80 | } |
81 | |
82 | bool resizable() const { |
83 | return storage_impl_->resizable(); |
84 | } |
85 | |
86 | size_t nbytes() const { |
87 | return storage_impl_->nbytes(); |
88 | } |
89 | |
90 | SymInt sym_nbytes() const { |
91 | return storage_impl_->sym_nbytes(); |
92 | } |
93 | // get() use here is to get const-correctness |
94 | |
95 | void* data() const { |
96 | return storage_impl_.get()->data(); |
97 | } |
98 | |
99 | at::DataPtr& data_ptr() { |
100 | return storage_impl_->data_ptr(); |
101 | } |
102 | |
103 | const at::DataPtr& data_ptr() const { |
104 | return storage_impl_->data_ptr(); |
105 | } |
106 | |
107 | // Returns the previous data_ptr |
108 | at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) const { |
109 | return storage_impl_.get()->set_data_ptr(std::move(data_ptr)); |
110 | } |
111 | |
112 | void set_data_ptr_noswap(at::DataPtr&& data_ptr) const { |
113 | return storage_impl_.get()->set_data_ptr_noswap(std::move(data_ptr)); |
114 | } |
115 | |
116 | DeviceType device_type() const { |
117 | return storage_impl_->device_type(); |
118 | } |
119 | |
120 | at::Allocator* allocator() const { |
121 | return storage_impl_.get()->allocator(); |
122 | } |
123 | |
124 | at::Device device() const { |
125 | return storage_impl_->device(); |
126 | } |
127 | |
128 | StorageImpl* unsafeReleaseStorageImpl() { |
129 | return storage_impl_.release(); |
130 | } |
131 | |
132 | StorageImpl* unsafeGetStorageImpl() const noexcept { |
133 | return storage_impl_.get(); |
134 | } |
135 | |
136 | c10::weak_intrusive_ptr<StorageImpl> getWeakStorageImpl() const { |
137 | return c10::weak_intrusive_ptr<StorageImpl>(storage_impl_); |
138 | } |
139 | |
140 | operator bool() const { |
141 | return storage_impl_; |
142 | } |
143 | |
144 | size_t use_count() const { |
145 | return storage_impl_.use_count(); |
146 | } |
147 | |
148 | inline bool unique() const { |
149 | return storage_impl_.unique(); |
150 | } |
151 | |
152 | bool is_alias_of(const Storage& other) const { |
153 | return storage_impl_ == other.storage_impl_; |
154 | } |
155 | |
156 | void UniqueStorageShareExternalPointer( |
157 | void* src, |
158 | size_t capacity, |
159 | DeleterFnPtr d = nullptr) { |
160 | if (!storage_impl_.unique()) { |
161 | TORCH_CHECK( |
162 | false, |
163 | "UniqueStorageShareExternalPointer can only be called when use_count == 1" ); |
164 | } |
165 | storage_impl_->UniqueStorageShareExternalPointer(src, capacity, d); |
166 | } |
167 | |
168 | void UniqueStorageShareExternalPointer( |
169 | at::DataPtr&& data_ptr, |
170 | size_t capacity) { |
171 | if (!storage_impl_.unique()) { |
172 | TORCH_CHECK( |
173 | false, |
174 | "UniqueStorageShareExternalPointer can only be called when use_count == 1" ); |
175 | } |
176 | storage_impl_->UniqueStorageShareExternalPointer( |
177 | std::move(data_ptr), capacity); |
178 | } |
179 | |
180 | protected: |
181 | c10::intrusive_ptr<StorageImpl> storage_impl_; |
182 | }; |
183 | |
184 | } // namespace c10 |
185 | |