1 | #pragma once |
2 | |
3 | #include <c10/core/Allocator.h> |
4 | #include <c10/core/ScalarType.h> |
5 | #include <c10/core/SymInt.h> |
6 | |
7 | #include <c10/util/intrusive_ptr.h> |
8 | |
9 | namespace c10 { |
10 | |
11 | // A storage represents the underlying backing data buffer for a |
12 | // tensor. This concept was inherited from the original Torch7 |
13 | // codebase; we'd kind of like to get rid of the concept |
14 | // (see https://github.com/pytorch/pytorch/issues/14797) but |
15 | // it's hard work and no one has gotten around to doing it. |
16 | // |
17 | // NB: storage is supposed to uniquely own a data pointer; e.g., |
18 | // two non-null data pointers alias if and only if they are from |
19 | // the same storage. Technically you can violate this invariant |
20 | // (e.g., you can create a non-owning StorageImpl with at::from_blob) |
21 | // but a lot of things won't work correctly, including: |
22 | // |
23 | // - An ordinary deleter on such a storage is wrong, because normal deleters |
24 | // assume unique ownership, but if you have two storages at the same data, |
25 | // that implies there is some sort of shared ownership. So your deleter would |
26 | // have to actually be internally doing some sort of refcount thing |
// - Deepcopy on the Python side relies on storage equality, not data pointer
//   equality; so if there are two separate storages pointing to the same data,
//   the data will actually get duplicated in that case (one data ptr before,
//   two data ptrs after)
// - Version counts won't work correctly, because we do all VC tracking at the
//   level of storages (unless you explicitly disconnect the VC with detach);
//   mutations through one storage are totally untracked by the other storage
//   aliasing the same data pointer
34 | struct C10_API StorageImpl : public c10::intrusive_ptr_target { |
35 | public: |
36 | struct use_byte_size_t {}; |
37 | |
38 | StorageImpl( |
39 | use_byte_size_t /*use_byte_size*/, |
40 | SymInt size_bytes, |
41 | at::DataPtr data_ptr, |
42 | at::Allocator* allocator, |
43 | bool resizable) |
44 | : data_ptr_(std::move(data_ptr)), |
45 | size_bytes_(std::move(size_bytes)), |
46 | size_bytes_is_symbolic_(size_bytes_.is_symbolic()), |
47 | resizable_(resizable), |
48 | received_cuda_(false), |
49 | allocator_(allocator) { |
50 | if (resizable) { |
51 | TORCH_INTERNAL_ASSERT( |
52 | allocator_, "For resizable storage, allocator must be provided" ); |
53 | } |
54 | } |
55 | |
56 | StorageImpl( |
57 | use_byte_size_t /*use_byte_size*/, |
58 | SymInt size_bytes, |
59 | at::Allocator* allocator, |
60 | bool resizable) |
61 | : StorageImpl( |
62 | use_byte_size_t(), |
63 | size_bytes, |
64 | size_bytes.is_symbolic() |
65 | ? allocator->allocate(0) |
66 | : allocator->allocate(size_bytes.as_int_unchecked()), |
67 | allocator, |
68 | resizable) {} |
69 | |
70 | StorageImpl& operator=(StorageImpl&& other) = default; |
71 | StorageImpl& operator=(const StorageImpl&) = delete; |
72 | StorageImpl() = delete; |
73 | StorageImpl(StorageImpl&& other) = default; |
74 | StorageImpl(const StorageImpl&) = delete; |
75 | ~StorageImpl() override = default; |
76 | |
77 | void reset() { |
78 | data_ptr_.clear(); |
79 | size_bytes_ = 0; |
80 | size_bytes_is_symbolic_ = false; |
81 | } |
82 | |
83 | template <typename T> |
84 | inline T* data() const { |
85 | return unsafe_data<T>(); |
86 | } |
87 | |
88 | template <typename T> |
89 | inline T* unsafe_data() const { |
90 | return static_cast<T*>(this->data_ptr_.get()); |
91 | } |
92 | |
93 | // Destructor doesn't call release_resources because it's |
94 | // unnecessary; don't forget to change that if needed! |
95 | void release_resources() override { |
96 | data_ptr_.clear(); |
97 | } |
98 | |
99 | size_t nbytes() const { |
100 | TORCH_CHECK(!size_bytes_is_symbolic_); |
101 | return size_bytes_.as_int_unchecked(); |
102 | } |
103 | |
104 | SymInt sym_nbytes() const { |
105 | return size_bytes_; |
106 | } |
107 | |
108 | // TODO: remove later |
109 | void set_nbytes(size_t size_bytes) { |
110 | size_bytes_ = size_bytes; |
111 | size_bytes_is_symbolic_ = false; |
112 | } |
113 | |
114 | void set_nbytes(c10::SymInt size_bytes) { |
115 | size_bytes_ = std::move(size_bytes); |
116 | } |
117 | |
118 | bool resizable() const { |
119 | return resizable_; |
120 | }; |
121 | |
122 | at::DataPtr& data_ptr() { |
123 | return data_ptr_; |
124 | }; |
125 | |
126 | const at::DataPtr& data_ptr() const { |
127 | return data_ptr_; |
128 | }; |
129 | |
130 | // Returns the previous data_ptr |
131 | at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) { |
132 | at::DataPtr old_data_ptr(std::move(data_ptr_)); |
133 | data_ptr_ = std::move(data_ptr); |
134 | return old_data_ptr; |
135 | }; |
136 | |
137 | void set_data_ptr_noswap(at::DataPtr&& data_ptr) { |
138 | data_ptr_ = std::move(data_ptr); |
139 | } |
140 | |
141 | // TODO: Return const ptr eventually if possible |
142 | void* data() { |
143 | return data_ptr_.get(); |
144 | } |
145 | |
146 | void* data() const { |
147 | return data_ptr_.get(); |
148 | } |
149 | |
150 | at::DeviceType device_type() const { |
151 | return data_ptr_.device().type(); |
152 | } |
153 | |
154 | at::Allocator* allocator() { |
155 | return allocator_; |
156 | } |
157 | |
158 | const at::Allocator* allocator() const { |
159 | return allocator_; |
160 | }; |
161 | |
162 | // You generally shouldn't use this method, but it is occasionally |
163 | // useful if you want to override how a tensor will be reallocated, |
164 | // after it was already allocated (and its initial allocator was |
165 | // set) |
166 | void set_allocator(at::Allocator* allocator) { |
167 | allocator_ = allocator; |
168 | } |
169 | |
170 | Device device() const { |
171 | return data_ptr_.device(); |
172 | } |
173 | |
174 | void set_resizable(bool resizable) { |
175 | if (resizable) { |
176 | // We need an allocator to be resizable |
177 | AT_ASSERT(allocator_); |
178 | } |
179 | resizable_ = resizable; |
180 | } |
181 | |
182 | /** |
183 | * Can only be called when use_count is 1 |
184 | */ |
185 | void UniqueStorageShareExternalPointer( |
186 | void* src, |
187 | size_t size_bytes, |
188 | DeleterFnPtr d = nullptr) { |
189 | UniqueStorageShareExternalPointer( |
190 | at::DataPtr(src, src, d, data_ptr_.device()), size_bytes); |
191 | } |
192 | |
193 | /** |
194 | * Can only be called when use_count is 1 |
195 | */ |
196 | void UniqueStorageShareExternalPointer( |
197 | at::DataPtr&& data_ptr, |
198 | size_t size_bytes) { |
199 | data_ptr_ = std::move(data_ptr); |
200 | size_bytes_ = size_bytes; |
201 | size_bytes_is_symbolic_ = false; |
202 | allocator_ = nullptr; |
203 | resizable_ = false; |
204 | } |
205 | |
206 | // This method can be used only after storage construction and cannot be used |
207 | // to modify storage status |
208 | void set_received_cuda(bool received_cuda) { |
209 | received_cuda_ = received_cuda; |
210 | } |
211 | |
212 | bool received_cuda() { |
213 | return received_cuda_; |
214 | } |
215 | |
216 | private: |
217 | DataPtr data_ptr_; |
218 | SymInt size_bytes_; |
219 | bool size_bytes_is_symbolic_; |
220 | bool resizable_; |
221 | // Identifies that Storage was received from another process and doesn't have |
222 | // local to process cuda memory allocation |
223 | bool received_cuda_; |
224 | Allocator* allocator_; |
225 | }; |
226 | } // namespace c10 |
227 | |