1#include <c10/util/numa.h>
2
3C10_DEFINE_bool(caffe2_cpu_numa_enabled, false, "Use NUMA whenever possible.");
4
5#if defined(__linux__) && defined(C10_USE_NUMA) && !defined(C10_MOBILE)
6#include <numa.h>
7#include <numaif.h>
8#include <unistd.h>
9#define C10_ENABLE_NUMA
10#endif
11
// This code used to have a lot of VLOGs. However, because allocation might be
// triggered during static initialization, it's unsafe to invoke VLOG here.
14
15namespace c10 {
16
17#ifdef C10_ENABLE_NUMA
18bool IsNUMAEnabled() {
19 return FLAGS_caffe2_cpu_numa_enabled && numa_available() >= 0;
20}
21
22void NUMABind(int numa_node_id) {
23 if (numa_node_id < 0) {
24 return;
25 }
26 if (!IsNUMAEnabled()) {
27 return;
28 }
29
30 TORCH_CHECK(
31 numa_node_id <= numa_max_node(),
32 "NUMA node id ",
33 numa_node_id,
34 " is unavailable");
35
36 auto bm = numa_allocate_nodemask();
37 numa_bitmask_setbit(bm, numa_node_id);
38 numa_bind(bm);
39 numa_bitmask_free(bm);
40}
41
42int GetNUMANode(const void* ptr) {
43 if (!IsNUMAEnabled()) {
44 return -1;
45 }
46 AT_ASSERT(ptr);
47
48 int numa_node = -1;
49 TORCH_CHECK(
50 get_mempolicy(
51 &numa_node,
52 nullptr,
53 0,
54 const_cast<void*>(ptr),
55 MPOL_F_NODE | MPOL_F_ADDR) == 0,
56 "Unable to get memory policy, errno:",
57 errno);
58 return numa_node;
59}
60
61int GetNumNUMANodes() {
62 if (!IsNUMAEnabled()) {
63 return -1;
64 }
65
66 return numa_num_configured_nodes();
67}
68
69void NUMAMove(void* ptr, size_t size, int numa_node_id) {
70 if (numa_node_id < 0) {
71 return;
72 }
73 if (!IsNUMAEnabled()) {
74 return;
75 }
76 AT_ASSERT(ptr);
77
78 uintptr_t page_start_ptr =
79 ((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
80 ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
81 // Avoid extra dynamic allocation and NUMA api calls
82 AT_ASSERT(
83 numa_node_id >= 0 &&
84 static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
85 unsigned long mask = 1UL << numa_node_id;
86 TORCH_CHECK(
87 mbind(
88 reinterpret_cast<void*>(page_start_ptr),
89 size + offset,
90 MPOL_BIND,
91 &mask,
92 sizeof(mask) * 8,
93 MPOL_MF_MOVE | MPOL_MF_STRICT) == 0,
94 "Could not move memory to a NUMA node");
95}
96
97int GetCurrentNUMANode() {
98 if (!IsNUMAEnabled()) {
99 return -1;
100 }
101
102 auto n = numa_node_of_cpu(sched_getcpu());
103 return n;
104}
105
106#else // C10_ENABLE_NUMA
107
// Stub used when C10_ENABLE_NUMA is not defined: NUMA is never enabled.
bool IsNUMAEnabled() {
  constexpr bool kNumaEnabled = false;
  return kNumaEnabled;
}
111
// Stub used when C10_ENABLE_NUMA is not defined: binding is a no-op.
void NUMABind(int numa_node_id) {
  static_cast<void>(numa_node_id);
}
113
// Stub used when C10_ENABLE_NUMA is not defined: always returns -1,
// regardless of ptr.
int GetNUMANode(const void* ptr) {
  static_cast<void>(ptr);
  constexpr int kNoNode = -1;
  return kNoNode;
}
117
// Stub used when C10_ENABLE_NUMA is not defined: always returns -1.
int GetNumNUMANodes() {
  constexpr int kNoNodes = -1;
  return kNoNodes;
}
121
// Stub used when C10_ENABLE_NUMA is not defined: moving memory is a no-op and
// the arguments are ignored.
void NUMAMove(void* ptr, size_t size, int numa_node_id) {
  static_cast<void>(ptr);
  static_cast<void>(size);
  static_cast<void>(numa_node_id);
}
123
// Stub used when C10_ENABLE_NUMA is not defined: always returns -1.
int GetCurrentNUMANode() {
  constexpr int kNoNode = -1;
  return kNoNode;
}
127
#endif // C10_ENABLE_NUMA
129
130} // namespace c10
131