1 | #include <torch/csrc/DataLoader.h> |
2 | |
3 | // Together with `torch/utils/data/_utils/signal_handling.py`, the following |
4 | // is an effort to do our best to provide some error message to users when a |
5 | // worker dies due to error / critical signals. |
6 | // |
7 | // See NOTE [ Signal handling in multiprocessing data loading ] for more |
8 | // details. |
9 | |
10 | // TODO: The following don't work on Windows. Specifically, sigaction, waitid |
11 | // calls, and SIGCHLD handler. Currently, dummy implementations are provided |
12 | // for Windows. |
13 | |
14 | #ifndef _WIN32 |
15 | |
16 | #include <torch/csrc/Exceptions.h> |
17 | #include <torch/csrc/utils/python_numbers.h> |
18 | |
19 | #include <c10/util/irange.h> |
20 | #include <fmt/format.h> |
21 | |
22 | #include <sys/wait.h> |
23 | #include <atomic> |
24 | #include <csignal> |
25 | #include <map> |
26 | #include <set> |
27 | #include <sstream> |
28 | |
29 | using namespace torch; |
30 | |
31 | // Critical signal handlers should be registered on worker processes before |
32 | // doing work. |
33 | // The handler will raise default handler so that the kill information will be |
34 | // retrieved from main process. |
35 | // Python handle is _set_worker_signal_handlers(). |
// NOTE: sizeof(ERROR_MSG) - 1 excludes the string literal's trailing NUL so
// the terminator byte is not written to stderr (the original
// sizeof(ERROR_MSG) / sizeof(char) included it). write(2) is used because it
// is async-signal-safe; its return value is deliberately discarded.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                 \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {       \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);   \
    (void)_w;                                                           \
    struct sigaction sa {};                                             \
    sa.sa_handler = SIG_DFL;                                            \
    sa.sa_flags = 0;                                                    \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                         \
      _exit(EXIT_FAILURE);                                              \
    } else {                                                            \
      raise(SIGNAL);                                                    \
    }                                                                   \
  }
51 | |
// signal(2) is really not portable. So use sigaction.
// http://man7.org/linux/man-pages/man2/signal.2.html
//
// Installs `handler` for `signal` via sigaction; the previous disposition is
// stored through `old_sa_ptr` when it is non-null. Throws std::runtime_error
// if installation fails.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction act {};
  act.sa_sigaction = handler;
  // Restart interrupted syscalls, deliver siginfo_t to the handler, ignore
  // stopped/continued children, and do not block the signal inside its own
  // handler.
  act.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  const bool installed = sigemptyset(&act.sa_mask) == 0 &&
      sigaction(signal, &act, old_sa_ptr) == 0;
  if (!installed) {
    std::ostringstream oss;
    oss << "An error occurred while setting handler for " << strsignal(signal)
        << ".";
    throw std::runtime_error(oss.str());
  }
}
69 | |
// Worker-side handlers for fatal signals. Each one writes a short explanatory
// message to stderr, restores the default disposition, and re-raises the
// signal so the parent process still observes the true kill reason.
SIGNAL_HANDLER(
    SIGBUS,
    handler_SIGBUS,
    "ERROR: Unexpected bus error encountered in worker. "
    "This might be caused by insufficient shared memory (shm).\n");
SIGNAL_HANDLER(
    SIGSEGV,
    handler_SIGSEGV,
    "ERROR: Unexpected segmentation fault encountered in worker.\n");
SIGNAL_HANDLER(
    SIGFPE,
    handler_SIGFPE,
    "ERROR: Unexpected floating-point exception encountered in worker.\n");
83 | |
84 | // When an error happened in DataLoader methods and Python starts to exit, the |
85 | // error trace will keep the loader alive, and Python may kill the children |
86 | // processes first before deleting the loader object. Then the cleaning up |
87 | // methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an |
88 | // error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main |
89 | // loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we |
90 | // exit with nonzero code, the loader SIGCHLD handler may report RuntimeError |
91 | // again, and then it defeats the whole purpose. |
// Worker-side SIGTERM handler: swallow a SIGTERM coming from the parent (see
// the rationale in the comment above), otherwise fall back to the default
// disposition and re-deliver the signal.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  // SIGTERM sent by the main loader process is expected during teardown;
  // exit quietly with success so the SIGCHLD handler does not report it.
  if (info->si_pid == getppid()) {
    _exit(EXIT_SUCCESS);
  }
  // Any other sender: restore SIG_DFL and raise again so the default action
  // (termination visible to waiters) takes place.
  struct sigaction dfl {};
  dfl.sa_handler = SIG_DFL;
  dfl.sa_flags = 0;
  const bool restored = sigemptyset(&dfl.sa_mask) == 0 &&
      sigaction(SIGTERM, &dfl, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
105 | |
106 | static PyObject* THPModule_setWorkerSignalHandlers( |
107 | PyObject* module, |
108 | PyObject* arg) { |
109 | HANDLE_TH_ERRORS |
110 | setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); |
111 | setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); |
112 | setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr); |
113 | setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr); |
114 | Py_RETURN_NONE; |
115 | END_HANDLE_TH_ERRORS |
116 | } |
117 | |
// Registry mapping each _BaseDataLoaderIter id to the set of worker pids it
// spawned. Populated by _set_worker_pids, scanned by
// _error_if_any_worker_fails, and removed by _remove_worker_pids.
// NOTE(review): only mutated from the Python-callable functions below, so the
// GIL presumably serializes access — confirm before using from other threads.
static std::map<int64_t, std::set<pid_t>> worker_pids = {};
119 | |
120 | static PyObject* THPModule_errorIfAnyWorkerFails( |
121 | PyObject* module, |
122 | PyObject* noargs) { |
123 | HANDLE_TH_ERRORS |
124 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
125 | int error; |
126 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
127 | std::set<pid_t>* pid_set; |
128 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
129 | pid_t worker_pid; |
130 | siginfo_t infop; |
131 | |
132 | // Only check the pids we care about |
133 | for (auto& w : worker_pids) { |
134 | pid_set = &(w.second); |
135 | for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) { |
136 | worker_pid = *pid_it; |
137 | // Use waitid rather than waitpid so that we can set NOWAIT, and that |
138 | // Python and other handlers can get whatever info they want about the |
139 | // child. |
140 | infop.si_pid = 0; |
141 | error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT); |
142 | // ignore errors and case with no waitable child |
143 | if (error < 0 || infop.si_pid == 0) |
144 | continue; |
145 | if (infop.si_code == CLD_EXITED && |
146 | infop.si_status != EXIT_SUCCESS) { // exit with error |
147 | std::ostringstream oss; |
148 | oss << "DataLoader worker (pid " << worker_pid << ") exited " |
149 | << "unexpectedly with exit code " << infop.si_status << ". " |
150 | << "Details are lost due to multiprocessing. Rerunning with " |
151 | << "num_workers=0 may give better error trace." ; |
152 | // This is necessary. Otherwise, the runtime error will kill the other |
153 | // workers, and trigger this again. |
154 | pid_set->clear(); |
155 | throw std::runtime_error(oss.str()); |
156 | } else if ( |
157 | infop.si_code == CLD_KILLED || |
158 | infop.si_code == CLD_DUMPED) { // killed by signal |
159 | std::ostringstream oss; |
160 | oss << "DataLoader worker (pid " << worker_pid << ") is killed " |
161 | << "by signal: " << strsignal(infop.si_status) << ". " ; |
162 | if (infop.si_status == SIGBUS) { |
163 | oss << "It is possible that dataloader's workers are out of shared memory. " |
164 | << "Please try to raise your shared memory limit." ; |
165 | } |
166 | // This is necessary. Otherwise, the runtime error will kill the other |
167 | // workers, and trigger this again. |
168 | pid_set->clear(); |
169 | throw std::runtime_error(oss.str()); |
170 | } |
171 | } |
172 | } |
173 | Py_RETURN_NONE; |
174 | END_HANDLE_TH_ERRORS |
175 | } |
176 | |
177 | // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple |
178 | // of pids we are interested in. |
179 | static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) { |
180 | HANDLE_TH_ERRORS |
181 | if (PyTuple_GET_SIZE(args) != 2) { |
182 | throw TypeError("_set_worker_pids expects exactly 2 arguments." ); |
183 | } |
184 | int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); |
185 | if (worker_pids.find(key) != worker_pids.end()) { |
186 | throw ValueError( |
187 | "_set_worker_pids should be called only once for each _BaseDataLoaderIter." ); |
188 | } |
189 | PyObject* child_pids = PyTuple_GET_ITEM(args, 1); |
190 | if (!PyTuple_Check(child_pids)) { |
191 | throw TypeError( |
192 | "_set_worker_pids expects a tuple for child_pids, but got %s." , |
193 | Py_TYPE(child_pids)->tp_name); |
194 | } |
195 | |
196 | std::set<pid_t> pids_set = {}; |
197 | auto size = PyTuple_GET_SIZE(child_pids); |
198 | for (const auto idx : c10::irange(size)) { |
199 | PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); |
200 | pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj))); |
201 | } |
202 | |
203 | worker_pids[key] = pids_set; |
204 | |
205 | Py_RETURN_NONE; |
206 | END_HANDLE_TH_ERRORS |
207 | } |
208 | |
209 | static PyObject* THPModule_removeWorkerPIDs( |
210 | PyObject* module, |
211 | PyObject* loader_id) { |
212 | HANDLE_TH_ERRORS |
213 | |
214 | int64_t key = THPUtils_unpackLong(loader_id); |
215 | auto it = worker_pids.find(key); |
216 | if (it == worker_pids.end()) { |
217 | throw ValueError(fmt::format( |
218 | "Cannot find worker information for _BaseDataLoaderIter with id {}" , |
219 | key)); |
220 | } |
221 | worker_pids.erase(it); |
222 | |
223 | Py_RETURN_NONE; |
224 | END_HANDLE_TH_ERRORS |
225 | } |
226 | |
227 | #undef SIGNAL_HANDLER |
228 | |
229 | #else |
230 | // dummy implementations for windows |
231 | |
// Windows no-op: sigaction-based worker signal handling is unavailable here.
static PyObject* THPModule_setWorkerSignalHandlers(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
237 | |
// Windows no-op: worker pids are not tracked on this platform.
static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
  Py_RETURN_NONE;
}
241 | |
// Windows no-op: nothing to remove since pids are never registered.
static PyObject* THPModule_removeWorkerPIDs(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
247 | |
// Windows no-op: waitid-based worker failure detection is unavailable here.
static PyObject* THPModule_errorIfAnyWorkerFails(
    PyObject* module,
    PyObject* _ignored) {
  Py_RETURN_NONE;
}
253 | |
254 | #endif |
255 | |
// Python-visible method table for this translation unit; these names back the
// helpers used by torch/utils/data/_utils/signal_handling.py.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
PyMethodDef DataLoaderMethods[] = {
    {"_set_worker_signal_handlers",
     THPModule_setWorkerSignalHandlers,
     METH_NOARGS,
     nullptr},
    {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
    {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
    {"_error_if_any_worker_fails",
     THPModule_errorIfAnyWorkerFails,
     METH_NOARGS,
     nullptr},
    // Sentinel entry terminating the table (required by the CPython API).
    {nullptr, nullptr, 0, nullptr}};
269 | |