1#pragma once
2
3#include <stddef.h>
4#include <memory>
5
6#include <c10/core/Device.h>
7#include <c10/util/Exception.h>
8#include <c10/util/ThreadLocalDebugInfo.h>
9#include <c10/util/UniqueVoidPtr.h>
10
11namespace c10 {
12
// A DataPtr is a unique pointer (with an attached deleter and some
// context for the deleter) to some memory, which also records which
// device its data resides on.
16//
17// nullptr DataPtrs can still have a nontrivial device; this allows
18// us to treat zero-size allocations uniformly with non-zero allocations.
19//
20class C10_API DataPtr {
21 private:
22 c10::detail::UniqueVoidPtr ptr_;
23 Device device_;
24
25 public:
26 // Choice of CPU here is arbitrary; if there's an "undefined" device
27 // we could use that too
28 DataPtr() : ptr_(), device_(DeviceType::CPU) {}
29 DataPtr(void* data, Device device) : ptr_(data), device_(device) {}
30 DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
31 : ptr_(data, ctx, ctx_deleter), device_(device) {}
32 void* operator->() const {
33 return ptr_.get();
34 }
35 void clear() {
36 ptr_.clear();
37 }
38 void* get() const {
39 return ptr_.get();
40 }
41 void* get_context() const {
42 return ptr_.get_context();
43 }
44 void* release_context() {
45 return ptr_.release_context();
46 }
47 std::unique_ptr<void, DeleterFnPtr>&& move_context() {
48 return ptr_.move_context();
49 }
50 operator bool() const {
51 return static_cast<bool>(ptr_);
52 }
53 template <typename T>
54 T* cast_context(DeleterFnPtr expected_deleter) const {
55 return ptr_.cast_context<T>(expected_deleter);
56 }
57 DeleterFnPtr get_deleter() const {
58 return ptr_.get_deleter();
59 }
60 /**
61 * Compare the deleter in a DataPtr to expected_deleter.
62 * If it matches, replace the deleter with new_deleter
63 * and return true; otherwise, does nothing and returns
64 * false.
65 *
66 * In general, it is not safe to unconditionally set the
67 * deleter on a DataPtr, because you don't know what
68 * the deleter is, and thus will have a hard time properly
69 * disposing of the deleter without storing the original
70 * deleter (this is difficult to do, because DeleterFnPtr
71 * is not a closure, and because the context on DataPtr is
72 * only a single word, you generally don't have enough
73 * space to store both the original deleter and its context).
74 * However, in some cases, you know /exactly/ what the deleter
75 * is, and you have a new deleter that manually wraps
76 * the old one. In this case, you can safely swap the deleter
77 * after asserting that the deleters line up.
78 *
79 * What are the requirements on new_deleter? It must still
80 * properly dispose of the void* pointer passed in as its argument,
81 * where void* is whatever the context of the original deleter
82 * is. So in general, you expect the new deleter to look something
83 * like this:
84 *
85 * [](void* ptr) {
86 * some_new_stuff(ptr);
87 * get_orig_allocator()->raw_deleter(ptr);
88 * }
89 *
90 * Note that it won't work to close over the original
91 * allocator; you don't have enough space to do that! Also,
92 * it's unsafe to assume that the passed in pointer in
93 * question is the memory pointer in question; it might not
94 * be; be sure to read the source code of the Allocator
95 * in question to confirm this.
96 */
97 C10_NODISCARD bool compare_exchange_deleter(
98 DeleterFnPtr expected_deleter,
99 DeleterFnPtr new_deleter) {
100 return ptr_.compare_exchange_deleter(expected_deleter, new_deleter);
101 }
102 Device device() const {
103 return device_;
104 }
105 // Unsafely mutates the device on a DataPtr. Under normal use,
106 // you should never actually need to call this function.
107 // We need this for the implementation of the hack detailed
108 // in Note [Masquerading as CUDA]
109 void unsafe_set_device(Device device) {
110 device_ = device;
111 }
112};
113
114// NB: Device is NOT tested for here; a CUDA nullptr is as much a nullptr as a
115// CPU nullptr
116
117inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept {
118 return !dp;
119}
120inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept {
121 return !dp;
122}
123inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept {
124 return dp;
125}
126inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept {
127 return dp;
128}
129
130// Note [raw_allocate/raw_deallocate and Thrust]
131// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132// Thrust's support for custom allocators requires us to write something
133// like this:
134//
135// class ThrustAllocator {
136// char* allocate(size_t);
137// void deallocate(char*, size_t);
138// };
139//
140// This is not good for our unique_ptr based allocator interface, as
141// there is no way to get to the context when we free.
142//
143// However, in some cases the context is exactly the same as
144// the data pointer. In this case, we can support the "raw"
145// allocate and deallocate interface. This is what
// raw_deleter signifies.  By default, it returns a nullptr, which means that
// the raw interface is not implemented.  Be sure to implement it whenever
// possible, or the raw interface will be incorrectly reported as unsupported
// even though it could be supported.
150
151struct C10_API Allocator {
152 virtual ~Allocator() = default;
153
154 virtual DataPtr allocate(size_t n) const = 0;
155
156 // If this returns a non nullptr, it means that allocate()
157 // is guaranteed to return a unique_ptr with this deleter attached;
158 // it means the rawAllocate and rawDeallocate APIs are safe to use.
159 // This function MUST always return the same BoundDeleter.
160 virtual DeleterFnPtr raw_deleter() const {
161 return nullptr;
162 }
163 void* raw_allocate(size_t n) {
164 auto dptr = allocate(n);
165 AT_ASSERT(dptr.get() == dptr.get_context());
166 return dptr.release_context();
167 }
168 void raw_deallocate(void* ptr) {
169 auto d = raw_deleter();
170 AT_ASSERT(d);
171 d(ptr);
172 }
173};
174
175// This context is used to generate DataPtr which have arbitrary
176// std::function deleters associated with them. In some user facing
177// functions, we give a (user-friendly) interface for constructing
178// tensors from external data which take an arbitrary std::function
179// deleter. Grep for InefficientStdFunctionContext to find these
180// occurrences.
181//
// This context is inefficient because we have to do a dynamic
// allocation of the InefficientStdFunctionContext itself, on top of the
// dynamic allocation which is implied by std::function.
185struct C10_API InefficientStdFunctionContext {
186 std::unique_ptr<void, std::function<void(void*)>> ptr_;
187 InefficientStdFunctionContext(
188 std::unique_ptr<void, std::function<void(void*)>>&& ptr)
189 : ptr_(std::move(ptr)) {}
190 static DataPtr makeDataPtr(
191 void* ptr,
192 const std::function<void(void*)>& deleter,
193 Device device);
194};
195
196/** Set the allocator for DeviceType `t`. The passed in allocator pointer is
197 * expected to have static lifetime; this function does NOT take ownership
198 * of the raw pointer. (The reason for this is to prevent existing pointers
199 * to an allocator of a particular device from being invalidated when
200 * SetAllocator is called.)
201 *
202 * Also note that this is not thread-safe, and we assume this function will
203 * only be called during initialization.
204 *
205 * The 'priority' flag is introduced when we want to overwrite the default
206 * allocator, since the allocators are set statically. The default priority
207 * is 0, which means the lowest. Only higher or equal priority can overwrite
208 * existing ones.
209 */
210C10_API void SetAllocator(DeviceType t, Allocator* alloc, uint8_t priority = 0);
211C10_API Allocator* GetAllocator(const DeviceType& t);
212
213template <DeviceType t>
214struct AllocatorRegisterer {
215 explicit AllocatorRegisterer(Allocator* alloc) {
216 SetAllocator(t, alloc);
217 }
218};
219
// Defines a translation-unit-local AllocatorRegisterer<t> constructed from
// `f`, so that `f` is registered for device type `t` when the object is
// initialized. The object always has the name g_allocator_d, so the macro
// can be expanded at most once per enclosing namespace in a translation
// unit. The unnamed namespace already gives the object internal linkage.
#define REGISTER_ALLOCATOR(t, f)                \
  namespace {                                   \
  c10::AllocatorRegisterer<t> g_allocator_d(f); \
  }
224
225// An interface for reporting thread local memory usage
226// per device
227struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
228 MemoryReportingInfoBase();
229 ~MemoryReportingInfoBase() override = default;
230
231 /**
232 * alloc_size corresponds to the size of the ptr.
233 *
234 * total_allocated corresponds to total allocated memory.
235 *
236 * total_reserved corresponds to total size of memory pool, both used and
237 * unused, if applicable.
238 */
239 virtual void reportMemoryUsage(
240 void* ptr,
241 int64_t alloc_size,
242 size_t total_allocated,
243 size_t total_reserved,
244 Device device) = 0;
245
246 virtual void reportOutOfMemory(
247 int64_t alloc_size,
248 size_t total_allocated,
249 size_t total_reserved,
250 Device device);
251
252 virtual bool memoryProfilingEnabled() const = 0;
253};
254
255C10_API bool memoryProfilingEnabled();
256C10_API void reportMemoryUsageToProfiler(
257 void* ptr,
258 int64_t alloc_size,
259 size_t total_allocated,
260 size_t total_reserved,
261 Device device);
262
263C10_API void reportOutOfMemoryToProfiler(
264 int64_t alloc_size,
265 size_t total_allocated,
266 size_t total_reserved,
267 Device device);
268
269} // namespace c10
270