1 | #include <c10/core/Allocator.h> |
2 | |
3 | #include <c10/util/ThreadLocalDebugInfo.h> |
4 | |
5 | namespace c10 { |
6 | |
7 | static void deleteInefficientStdFunctionContext(void* ptr) { |
8 | delete static_cast<InefficientStdFunctionContext*>(ptr); |
9 | } |
10 | |
11 | at::DataPtr InefficientStdFunctionContext::makeDataPtr( |
12 | void* ptr, |
13 | const std::function<void(void*)>& deleter, |
14 | Device device) { |
15 | return { |
16 | ptr, |
17 | new InefficientStdFunctionContext({ptr, deleter}), |
18 | &deleteInefficientStdFunctionContext, |
19 | device}; |
20 | } |
21 | |
22 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
23 | C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES]; |
24 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
25 | C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {0}; |
26 | |
27 | void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) { |
28 | if (priority >= allocator_priority[static_cast<int>(t)]) { |
29 | allocator_array[static_cast<int>(t)] = alloc; |
30 | allocator_priority[static_cast<int>(t)] = priority; |
31 | } |
32 | } |
33 | |
34 | at::Allocator* GetAllocator(const at::DeviceType& t) { |
35 | auto* alloc = allocator_array[static_cast<int>(t)]; |
36 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(alloc, "Allocator for " , t, " is not set." ); |
37 | return alloc; |
38 | } |
39 | |
40 | bool memoryProfilingEnabled() { |
41 | auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>( |
42 | ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); |
43 | return reporter_ptr && reporter_ptr->memoryProfilingEnabled(); |
44 | } |
45 | |
46 | void reportMemoryUsageToProfiler( |
47 | void* ptr, |
48 | int64_t alloc_size, |
49 | size_t total_allocated, |
50 | size_t total_reserved, |
51 | Device device) { |
52 | auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>( |
53 | ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); |
54 | if (reporter_ptr) { |
55 | reporter_ptr->reportMemoryUsage( |
56 | ptr, alloc_size, total_allocated, total_reserved, device); |
57 | } |
58 | } |
59 | |
60 | void reportOutOfMemoryToProfiler( |
61 | int64_t alloc_size, |
62 | size_t total_allocated, |
63 | size_t total_reserved, |
64 | Device device) { |
65 | auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>( |
66 | ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE)); |
67 | if (reporter_ptr) { |
68 | reporter_ptr->reportOutOfMemory( |
69 | alloc_size, total_allocated, total_reserved, device); |
70 | } |
71 | } |
72 | |
73 | MemoryReportingInfoBase::MemoryReportingInfoBase() = default; |
74 | |
75 | void MemoryReportingInfoBase::reportOutOfMemory( |
76 | int64_t /*alloc_size*/, |
77 | size_t /*total_allocated*/, |
78 | size_t /*total_reserved*/, |
79 | Device /*device*/) {} |
80 | |
81 | } // namespace c10 |
82 | |