1#include <c10/core/Allocator.h>
2
3#include <c10/util/ThreadLocalDebugInfo.h>
4
5namespace c10 {
6
7static void deleteInefficientStdFunctionContext(void* ptr) {
8 delete static_cast<InefficientStdFunctionContext*>(ptr);
9}
10
11at::DataPtr InefficientStdFunctionContext::makeDataPtr(
12 void* ptr,
13 const std::function<void(void*)>& deleter,
14 Device device) {
15 return {
16 ptr,
17 new InefficientStdFunctionContext({ptr, deleter}),
18 &deleteInefficientStdFunctionContext,
19 device};
20}
21
22// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
23C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES];
24// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
25C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {0};
26
27void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) {
28 if (priority >= allocator_priority[static_cast<int>(t)]) {
29 allocator_array[static_cast<int>(t)] = alloc;
30 allocator_priority[static_cast<int>(t)] = priority;
31 }
32}
33
34at::Allocator* GetAllocator(const at::DeviceType& t) {
35 auto* alloc = allocator_array[static_cast<int>(t)];
36 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alloc, "Allocator for ", t, " is not set.");
37 return alloc;
38}
39
40bool memoryProfilingEnabled() {
41 auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
42 ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
43 return reporter_ptr && reporter_ptr->memoryProfilingEnabled();
44}
45
46void reportMemoryUsageToProfiler(
47 void* ptr,
48 int64_t alloc_size,
49 size_t total_allocated,
50 size_t total_reserved,
51 Device device) {
52 auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
53 ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
54 if (reporter_ptr) {
55 reporter_ptr->reportMemoryUsage(
56 ptr, alloc_size, total_allocated, total_reserved, device);
57 }
58}
59
60void reportOutOfMemoryToProfiler(
61 int64_t alloc_size,
62 size_t total_allocated,
63 size_t total_reserved,
64 Device device) {
65 auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
66 ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
67 if (reporter_ptr) {
68 reporter_ptr->reportOutOfMemory(
69 alloc_size, total_allocated, total_reserved, device);
70 }
71}
72
73MemoryReportingInfoBase::MemoryReportingInfoBase() = default;
74
75void MemoryReportingInfoBase::reportOutOfMemory(
76 int64_t /*alloc_size*/,
77 size_t /*total_allocated*/,
78 size_t /*total_reserved*/,
79 Device /*device*/) {}
80
81} // namespace c10
82