1/*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#include "memory_debug.hpp"
18
19#ifdef _WIN32
20#define WIN32_LEAN_AND_MEAN
21#include <windows.h>
22#endif
23
24#if defined __unix__ || defined __APPLE__ || defined __FreeBSD__ \
25 || defined __Fuchsia__
26#include <unistd.h>
27#include <sys/mman.h>
28#endif
29
30#include <assert.h>
31
32#include "dnnl_thread.hpp"
33#include "nstl.hpp"
34#include "utils.hpp"
35
36namespace dnnl {
37namespace impl {
38namespace memory_debug {
39
// Rounds `ptr` down to the start of the page containing it and returns
// the result converted to T (assumes the page size is a power of two,
// so modulo arithmetic matches the original bit-mask behavior).
template <typename T>
static inline T get_page_start(const void *ptr) {
    const size_t page_size = static_cast<size_t>(getpagesize());
    const size_t addr = reinterpret_cast<size_t>(ptr);
    return reinterpret_cast<T>(addr - (addr % page_size));
}
46
// Rounds `ptr` up to the next page boundary (a pointer already on a
// boundary is returned unchanged) and converts the result to T.
template <typename T>
static inline T get_page_end(const void *ptr) {
    const size_t page_size = static_cast<size_t>(getpagesize());
    const size_t addr = reinterpret_cast<size_t>(ptr);
    const size_t remainder = addr % page_size;
    const size_t rounded = remainder ? addr + (page_size - remainder) : addr;
    return reinterpret_cast<T>(rounded);
}
53
54static inline int num_protect_pages() {
55 if (is_mem_debug())
56 return DNNL_MEM_DEBUG_PROTECT_SIZE;
57 else
58 return 0;
59}
60
61size_t protect_size() {
62 return (size_t)num_protect_pages() * getpagesize();
63}
64
#ifdef _WIN32
#define PROT_NONE 0
#define PROT_READ 1
#define PROT_WRITE 2
// Stub: Windows has no POSIX mprotect(), so guard-page protection is
// effectively a no-op there; always reports success so the callers'
// asserts do not fire.
static inline int mprotect(void *addr, size_t len, int prot) {
    // TODO: Create a mprotect emulation layer to improve debug scratchpad
    // support on windows. This should require the windows.h and memoryapi.h
    // headers
    return 0;
}
#endif
76
// Bookkeeping record stored immediately before the first page of each
// debug allocation (inside the leading guard region, see malloc/free).
struct memory_tag_t {
    void *memory_start; // raw pointer returned by the underlying allocator
    size_t buffer_size; // user-visible buffer size in bytes
};
81
82static inline memory_tag_t *get_memory_tags(void *ptr) {
83 return get_page_start<memory_tag_t *>(ptr) - 1;
84}
85
// Debug allocator: returns `size` bytes aligned to `alignment`, with
// PROT_NONE guard regions placed before and after the buffer so that
// out-of-bounds accesses fault immediately.
// Raw layout: [guard pages][tag][user buffer (+ slack)][guard pages].
// Returns nullptr on allocation failure. Must be released with
// memory_debug::free().
void *malloc(size_t size, int alignment) {
    void *ptr;

    // User-visible buffer size, rounded up so the buffer end can be
    // positioned against the trailing guard page (overflow mode below).
    size_t buffer_size = utils::rnd_up(size, alignment);
    int buffer_alignment = alignment;
    // Guard protection works at page granularity, so the underlying
    // allocation is aligned to at least one page.
    if (buffer_alignment < getpagesize()) alignment = getpagesize();
    // Reserve extra room: alignment slack plus a guard region on each side.
    size = utils::rnd_up(
            size + alignment + 2 * protect_size(), (size_t)alignment);

#ifdef _WIN32
    ptr = _aligned_malloc(size, alignment);
    int rc = ptr ? 0 : -1;
#else
    int rc = ::posix_memalign(&ptr, alignment, size);
#endif

    if (rc == 0) {
        // Keep the raw allocation start; it is what gets freed later.
        void *mem_start = ptr;
        // Step past the leading guard region and re-align.
        ptr = utils::align_ptr(
                reinterpret_cast<char *>(ptr) + protect_size(), alignment);
        if (is_mem_debug_overflow()) {
            // Shift the buffer so its END abuts the trailing guard page:
            // overflows (rather than underflows) fault on first access.
            size_t offset = (alignment - (buffer_size % alignment)) % alignment;
            ptr = reinterpret_cast<char *>(ptr) + offset;
        }
        // The tag sits just before the buffer's first page, inside the
        // leading guard region; the region must be able to hold it.
        assert(protect_size() >= 16);
        memory_tag_t *tag = get_memory_tags(ptr);
        tag->memory_start = mem_start;
        tag->buffer_size = buffer_size;
        // Arm the guard pages and poison the buffer with a canary pattern.
        protect_buffer(ptr, buffer_size, engine_kind_t::dnnl_cpu);
    }

    return (rc == 0) ? ptr : nullptr;
}
119
// Releases a pointer previously returned by memory_debug::malloc().
// nullptr is forwarded to the underlying free, which accepts it.
void free(void *p) {
    if (p != nullptr) {
        memory_tag_t *tag = get_memory_tags(p);
        int status;
        MAYBE_UNUSED(status);

        // The tag lives inside the leading guard region, which is
        // PROT_NONE while the buffer is live; its page must be made
        // accessible BEFORE the tag is dereferenced below.
        status = mprotect(get_page_start<void *>(tag), getpagesize(),
                PROT_WRITE | PROT_READ);
        assert(status == 0);
        // Drop the remaining guard protections so the allocator can
        // reuse the whole region.
        unprotect_buffer(p, tag->buffer_size, engine_kind_t::dnnl_cpu);

        // Free the raw allocation start, not the user-visible pointer.
        p = tag->memory_start;
    }

#ifdef _WIN32
    _aligned_free(p);
#else
    ::free(p);
#endif
}
140
// Arms the guard regions around `addr` and poisons the buffer contents
// with a canary pattern.
// Assumes the input buffer is allocated such that there is num_protect_pages()
// pages surrounding the buffer (as done by memory_debug::malloc()).
void protect_buffer(void *addr, size_t size, engine_kind_t engine_kind) {
    if (engine_kind != engine_kind_t::dnnl_cpu)
        return; // Only CPU is supported currently

    // Page-aligned bounds of the region containing the user buffer.
    char *page_start = get_page_start<char *>(addr);
    char *page_end
            = get_page_end<char *>(reinterpret_cast<const char *>(addr) + size);
    int status;
    MAYBE_UNUSED(status);

    // Revoke all access to the guard regions on both sides: any stray
    // read/write there faults immediately.
    status = mprotect(page_start - protect_size(), protect_size(), PROT_NONE);
    assert(status == 0);
    status = mprotect(page_end, protect_size(), PROT_NONE);
    assert(status == 0);

    // The canary is set so that it will generate NaN for floating point
    // data types. This causes uninitialized memory usage on floating point
    // data to be poisoned, increasing the chance the error is caught.
    uint16_t canary = 0x7ff1;
    size_t work_amount = (size_t)((page_end - page_start) / getpagesize());
    if (work_amount <= 1) {
        // Avoid large memory initializations for small buffers
        // Start is rounded down to a 2-byte boundary so only whole
        // 16-bit canaries are written.
        uint16_t *ptr_start = reinterpret_cast<uint16_t *>(
                reinterpret_cast<size_t>(addr) & ~1);
        uint16_t *ptr_end = reinterpret_cast<uint16_t *>(
                reinterpret_cast<char *>(addr) + size);
        for (uint16_t *curr = ptr_start; curr < ptr_end; curr++) {
            *curr = canary;
        }
    } else {
        // Large buffer: poison whole pages in parallel, splitting the
        // page range evenly across threads.
        parallel(0, [&](const int ithr, const int nthr) {
            size_t start = 0, end = 0;
            balance211(work_amount, nthr, ithr, start, end);
            uint16_t *ptr_start = reinterpret_cast<uint16_t *>(
                    page_start + getpagesize() * start);
            uint16_t *ptr_end = reinterpret_cast<uint16_t *>(
                    page_start + getpagesize() * end);

            for (uint16_t *curr = ptr_start; curr < ptr_end; curr++) {
                *curr = canary;
            }
        });
    }
}
187
188// Assumes the input buffer is allocated such that there is num_protect_pages()
189// pages surrounding the buffer
190void unprotect_buffer(
191 const void *addr, size_t size, engine_kind_t engine_kind) {
192 if (engine_kind != engine_kind_t::dnnl_cpu)
193 return; // Only CPU is supported currently
194
195 char *page_start = get_page_start<char *>(addr);
196 char *page_end
197 = get_page_end<char *>(reinterpret_cast<const char *>(addr) + size);
198 int status;
199 MAYBE_UNUSED(status);
200
201 status = mprotect(page_start - protect_size(), protect_size(),
202 PROT_WRITE | PROT_READ);
203 assert(status == 0);
204 status = mprotect(page_end, protect_size(), PROT_WRITE | PROT_READ);
205 assert(status == 0);
206}
207
208} // namespace memory_debug
209} // namespace impl
210} // namespace dnnl
211