1#pragma once
2
3#include <c10/core/Allocator.h>
4#include <c10/core/ScalarType.h>
5#include <c10/core/SymInt.h>
6
7#include <c10/util/intrusive_ptr.h>
8
9namespace c10 {
10
11// A storage represents the underlying backing data buffer for a
12// tensor. This concept was inherited from the original Torch7
13// codebase; we'd kind of like to get rid of the concept
14// (see https://github.com/pytorch/pytorch/issues/14797) but
15// it's hard work and no one has gotten around to doing it.
16//
17// NB: storage is supposed to uniquely own a data pointer; e.g.,
18// two non-null data pointers alias if and only if they are from
19// the same storage. Technically you can violate this invariant
20// (e.g., you can create a non-owning StorageImpl with at::from_blob)
21// but a lot of things won't work correctly, including:
22//
23// - An ordinary deleter on such a storage is wrong, because normal deleters
24// assume unique ownership, but if you have two storages at the same data,
25// that implies there is some sort of shared ownership. So your deleter would
26// have to actually be internally doing some sort of refcount thing
// - Deepcopy on the Python side relies on storage equality, not data pointer
// equality; so if there are two separate storages pointing to the same data,
// the data will actually get duplicated in that case (one data ptr before,
// two data ptrs after)
// - Version counts won't work correctly, because we do all VC tracking at the
// level of storages (unless you explicitly disconnect the VC with detach);
// mutations go totally untracked because the two data pointers are the same
34struct C10_API StorageImpl : public c10::intrusive_ptr_target {
35 public:
36 struct use_byte_size_t {};
37
38 StorageImpl(
39 use_byte_size_t /*use_byte_size*/,
40 SymInt size_bytes,
41 at::DataPtr data_ptr,
42 at::Allocator* allocator,
43 bool resizable)
44 : data_ptr_(std::move(data_ptr)),
45 size_bytes_(std::move(size_bytes)),
46 size_bytes_is_symbolic_(size_bytes_.is_symbolic()),
47 resizable_(resizable),
48 received_cuda_(false),
49 allocator_(allocator) {
50 if (resizable) {
51 TORCH_INTERNAL_ASSERT(
52 allocator_, "For resizable storage, allocator must be provided");
53 }
54 }
55
56 StorageImpl(
57 use_byte_size_t /*use_byte_size*/,
58 SymInt size_bytes,
59 at::Allocator* allocator,
60 bool resizable)
61 : StorageImpl(
62 use_byte_size_t(),
63 size_bytes,
64 size_bytes.is_symbolic()
65 ? allocator->allocate(0)
66 : allocator->allocate(size_bytes.as_int_unchecked()),
67 allocator,
68 resizable) {}
69
70 StorageImpl& operator=(StorageImpl&& other) = default;
71 StorageImpl& operator=(const StorageImpl&) = delete;
72 StorageImpl() = delete;
73 StorageImpl(StorageImpl&& other) = default;
74 StorageImpl(const StorageImpl&) = delete;
75 ~StorageImpl() override = default;
76
77 void reset() {
78 data_ptr_.clear();
79 size_bytes_ = 0;
80 size_bytes_is_symbolic_ = false;
81 }
82
83 template <typename T>
84 inline T* data() const {
85 return unsafe_data<T>();
86 }
87
88 template <typename T>
89 inline T* unsafe_data() const {
90 return static_cast<T*>(this->data_ptr_.get());
91 }
92
93 // Destructor doesn't call release_resources because it's
94 // unnecessary; don't forget to change that if needed!
95 void release_resources() override {
96 data_ptr_.clear();
97 }
98
99 size_t nbytes() const {
100 TORCH_CHECK(!size_bytes_is_symbolic_);
101 return size_bytes_.as_int_unchecked();
102 }
103
104 SymInt sym_nbytes() const {
105 return size_bytes_;
106 }
107
108 // TODO: remove later
109 void set_nbytes(size_t size_bytes) {
110 size_bytes_ = size_bytes;
111 size_bytes_is_symbolic_ = false;
112 }
113
114 void set_nbytes(c10::SymInt size_bytes) {
115 size_bytes_ = std::move(size_bytes);
116 }
117
118 bool resizable() const {
119 return resizable_;
120 };
121
122 at::DataPtr& data_ptr() {
123 return data_ptr_;
124 };
125
126 const at::DataPtr& data_ptr() const {
127 return data_ptr_;
128 };
129
130 // Returns the previous data_ptr
131 at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
132 at::DataPtr old_data_ptr(std::move(data_ptr_));
133 data_ptr_ = std::move(data_ptr);
134 return old_data_ptr;
135 };
136
137 void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
138 data_ptr_ = std::move(data_ptr);
139 }
140
141 // TODO: Return const ptr eventually if possible
142 void* data() {
143 return data_ptr_.get();
144 }
145
146 void* data() const {
147 return data_ptr_.get();
148 }
149
150 at::DeviceType device_type() const {
151 return data_ptr_.device().type();
152 }
153
154 at::Allocator* allocator() {
155 return allocator_;
156 }
157
158 const at::Allocator* allocator() const {
159 return allocator_;
160 };
161
162 // You generally shouldn't use this method, but it is occasionally
163 // useful if you want to override how a tensor will be reallocated,
164 // after it was already allocated (and its initial allocator was
165 // set)
166 void set_allocator(at::Allocator* allocator) {
167 allocator_ = allocator;
168 }
169
170 Device device() const {
171 return data_ptr_.device();
172 }
173
174 void set_resizable(bool resizable) {
175 if (resizable) {
176 // We need an allocator to be resizable
177 AT_ASSERT(allocator_);
178 }
179 resizable_ = resizable;
180 }
181
182 /**
183 * Can only be called when use_count is 1
184 */
185 void UniqueStorageShareExternalPointer(
186 void* src,
187 size_t size_bytes,
188 DeleterFnPtr d = nullptr) {
189 UniqueStorageShareExternalPointer(
190 at::DataPtr(src, src, d, data_ptr_.device()), size_bytes);
191 }
192
193 /**
194 * Can only be called when use_count is 1
195 */
196 void UniqueStorageShareExternalPointer(
197 at::DataPtr&& data_ptr,
198 size_t size_bytes) {
199 data_ptr_ = std::move(data_ptr);
200 size_bytes_ = size_bytes;
201 size_bytes_is_symbolic_ = false;
202 allocator_ = nullptr;
203 resizable_ = false;
204 }
205
206 // This method can be used only after storage construction and cannot be used
207 // to modify storage status
208 void set_received_cuda(bool received_cuda) {
209 received_cuda_ = received_cuda;
210 }
211
212 bool received_cuda() {
213 return received_cuda_;
214 }
215
216 private:
217 DataPtr data_ptr_;
218 SymInt size_bytes_;
219 bool size_bytes_is_symbolic_;
220 bool resizable_;
221 // Identifies that Storage was received from another process and doesn't have
222 // local to process cuda memory allocation
223 bool received_cuda_;
224 Allocator* allocator_;
225};
226} // namespace c10
227