1#include <torch/csrc/DataLoader.h>
2
3// Together with `torch/utils/data/_utils/signal_handling.py`, the following
4// is an effort to do our best to provide some error message to users when a
5// worker dies due to error / critical signals.
6//
7// See NOTE [ Signal handling in multiprocessing data loading ] for more
8// details.
9
10// TODO: The following don't work on Windows. Specifically, sigaction, waitid
11// calls, and SIGCHLD handler. Currently, dummy implementations are provided
12// for Windows.
13
14#ifndef _WIN32
15
16#include <torch/csrc/Exceptions.h>
17#include <torch/csrc/utils/python_numbers.h>
18
19#include <c10/util/irange.h>
20#include <fmt/format.h>
21
22#include <sys/wait.h>
23#include <atomic>
24#include <csignal>
25#include <map>
26#include <set>
27#include <sstream>
28
29using namespace torch;
30
// Critical signal handlers should be registered on worker processes before
// doing work. Python handle is _set_worker_signal_handlers().
//
// SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) defines a file-local
// SA_SIGINFO-style handler named HANDLER_NAME that:
//   1. writes ERROR_MSG to stderr with one write(2) call,
//   2. resets the disposition of SIGNAL to SIG_DFL, and
//   3. re-raises SIGNAL, so the default action runs and the kill information
//      can be retrieved from the main process.
// If the default disposition cannot be restored, the process terminates with
// _exit(EXIT_FAILURE) instead.
//
// ERROR_MSG must be a string literal: its length is computed with sizeof, and
// the `- 1` keeps the literal's terminating NUL byte out of stderr.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) { \
    (void)sig; \
    (void)info; \
    (void)ctx; \
    /* Best effort: the result of write() is deliberately ignored. */ \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1); \
    (void)_w; \
    struct sigaction sa {}; \
    sa.sa_handler = SIG_DFL; \
    sa.sa_flags = 0; \
    if (sigemptyset(&sa.sa_mask) != 0 || \
        sigaction(SIGNAL, &sa, nullptr) != 0) { \
      _exit(EXIT_FAILURE); \
    } else { \
      raise(SIGNAL); \
    } \
  }
51
// Installs `handler` as the SA_SIGINFO-style handler for `signal`.
//
// sigaction is used because signal(2) is really not portable, see
// http://man7.org/linux/man-pages/man2/signal.2.html
//
// `old_sa_ptr` is forwarded to sigaction() as the out-parameter for the
// previous disposition and may be nullptr. Throws std::runtime_error when
// either sigemptyset() or sigaction() reports failure.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;

  // sigaction() is only attempted once the mask has been cleared successfully.
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for ";
  message << strsignal(signal);
  message << ".";
  throw std::runtime_error(message.str());
}
69
70SIGNAL_HANDLER(
71 SIGBUS,
72 handler_SIGBUS,
73 "ERROR: Unexpected bus error encountered in worker. "
74 "This might be caused by insufficient shared memory (shm).\n");
75SIGNAL_HANDLER(
76 SIGSEGV,
77 handler_SIGSEGV,
78 "ERROR: Unexpected segmentation fault encountered in worker.\n");
79SIGNAL_HANDLER(
80 SIGFPE,
81 handler_SIGFPE,
82 "ERROR: Unexpected floating-point exception encountered in worker.\n");
83
// SIGTERM handler for worker processes.
//
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, and Python may kill the child processes
// before deleting the loader object. The cleanup in DataLoader.__del__ has not
// run yet at that point, so the SIGCHLD handler would print an error saying a
// worker was killed by SIGTERM. To avoid that, a SIGTERM sent by the parent
// (main loader) process is suppressed with _exit(EXIT_SUCCESS). Exiting with a
// nonzero code would let the loader's SIGCHLD handler report a RuntimeError
// again, which defeats the whole purpose.
//
// A SIGTERM from any other sender restores the default disposition and
// re-raises the signal; if restoring fails the process exits with
// EXIT_FAILURE.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  const bool sent_by_parent = (info->si_pid == getppid());
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;

  // sigaction() is only attempted once the mask has been cleared successfully.
  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
105
106static PyObject* THPModule_setWorkerSignalHandlers(
107 PyObject* module,
108 PyObject* arg) {
109 HANDLE_TH_ERRORS
110 setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
111 setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
112 setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
113 setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
114 Py_RETURN_NONE;
115 END_HANDLE_TH_ERRORS
116}
117
// Maps an integer key (one per data loader iterator, as supplied by the
// Python side) to the set of worker pids registered under that key.
static std::map<int64_t, std::set<pid_t>> worker_pids;
119
120static PyObject* THPModule_errorIfAnyWorkerFails(
121 PyObject* module,
122 PyObject* noargs) {
123 HANDLE_TH_ERRORS
124 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
125 int error;
126 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
127 std::set<pid_t>* pid_set;
128 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
129 pid_t worker_pid;
130 siginfo_t infop;
131
132 // Only check the pids we care about
133 for (auto& w : worker_pids) {
134 pid_set = &(w.second);
135 for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
136 worker_pid = *pid_it;
137 // Use waitid rather than waitpid so that we can set NOWAIT, and that
138 // Python and other handlers can get whatever info they want about the
139 // child.
140 infop.si_pid = 0;
141 error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
142 // ignore errors and case with no waitable child
143 if (error < 0 || infop.si_pid == 0)
144 continue;
145 if (infop.si_code == CLD_EXITED &&
146 infop.si_status != EXIT_SUCCESS) { // exit with error
147 std::ostringstream oss;
148 oss << "DataLoader worker (pid " << worker_pid << ") exited "
149 << "unexpectedly with exit code " << infop.si_status << ". "
150 << "Details are lost due to multiprocessing. Rerunning with "
151 << "num_workers=0 may give better error trace.";
152 // This is necessary. Otherwise, the runtime error will kill the other
153 // workers, and trigger this again.
154 pid_set->clear();
155 throw std::runtime_error(oss.str());
156 } else if (
157 infop.si_code == CLD_KILLED ||
158 infop.si_code == CLD_DUMPED) { // killed by signal
159 std::ostringstream oss;
160 oss << "DataLoader worker (pid " << worker_pid << ") is killed "
161 << "by signal: " << strsignal(infop.si_status) << ". ";
162 if (infop.si_status == SIGBUS) {
163 oss << "It is possible that dataloader's workers are out of shared memory. "
164 << "Please try to raise your shared memory limit.";
165 }
166 // This is necessary. Otherwise, the runtime error will kill the other
167 // workers, and trigger this again.
168 pid_set->clear();
169 throw std::runtime_error(oss.str());
170 }
171 }
172 }
173 Py_RETURN_NONE;
174 END_HANDLE_TH_ERRORS
175}
176
177// We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
178// of pids we are interested in.
179static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
180 HANDLE_TH_ERRORS
181 if (PyTuple_GET_SIZE(args) != 2) {
182 throw TypeError("_set_worker_pids expects exactly 2 arguments.");
183 }
184 int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
185 if (worker_pids.find(key) != worker_pids.end()) {
186 throw ValueError(
187 "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
188 }
189 PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
190 if (!PyTuple_Check(child_pids)) {
191 throw TypeError(
192 "_set_worker_pids expects a tuple for child_pids, but got %s.",
193 Py_TYPE(child_pids)->tp_name);
194 }
195
196 std::set<pid_t> pids_set = {};
197 auto size = PyTuple_GET_SIZE(child_pids);
198 for (const auto idx : c10::irange(size)) {
199 PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
200 pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
201 }
202
203 worker_pids[key] = pids_set;
204
205 Py_RETURN_NONE;
206 END_HANDLE_TH_ERRORS
207}
208
209static PyObject* THPModule_removeWorkerPIDs(
210 PyObject* module,
211 PyObject* loader_id) {
212 HANDLE_TH_ERRORS
213
214 int64_t key = THPUtils_unpackLong(loader_id);
215 auto it = worker_pids.find(key);
216 if (it == worker_pids.end()) {
217 throw ValueError(fmt::format(
218 "Cannot find worker information for _BaseDataLoaderIter with id {}",
219 key));
220 }
221 worker_pids.erase(it);
222
223 Py_RETURN_NONE;
224 END_HANDLE_TH_ERRORS
225}
226
227#undef SIGNAL_HANDLER
228
229#else
230// dummy implementations for windows
231
232static PyObject* THPModule_setWorkerSignalHandlers(
233 PyObject* module,
234 PyObject* _ignored) {
235 Py_RETURN_NONE;
236}
237
238static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
239 Py_RETURN_NONE;
240}
241
242static PyObject* THPModule_removeWorkerPIDs(
243 PyObject* module,
244 PyObject* _ignored) {
245 Py_RETURN_NONE;
246}
247
248static PyObject* THPModule_errorIfAnyWorkerFails(
249 PyObject* module,
250 PyObject* _ignored) {
251 Py_RETURN_NONE;
252}
253
254#endif
255
256// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
257PyMethodDef DataLoaderMethods[] = {
258 {"_set_worker_signal_handlers",
259 THPModule_setWorkerSignalHandlers,
260 METH_NOARGS,
261 nullptr},
262 {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
263 {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
264 {"_error_if_any_worker_fails",
265 THPModule_errorIfAnyWorkerFails,
266 METH_NOARGS,
267 nullptr},
268 {nullptr, nullptr, 0, nullptr}};
269