1 | #include <torch/csrc/Exceptions.h> |
2 | #include <torch/csrc/python_headers.h> |
3 | #include <torch/csrc/utils/object_ptr.h> |
4 | #include <torch/csrc/utils/pybind.h> |
5 | |
6 | #include <stdexcept> |
7 | |
8 | #if defined(__linux__) |
9 | #include <sys/prctl.h> |
10 | #endif |
11 | |
// Throws std::system_error carrying the current errno when `rv` is negative,
// the failure convention of POSIX-style calls such as prctl(). Any extra
// arguments are forwarded to the std::system_error constructor as the
// message (e.g. the name of the failing call).
//
// Wrapped in do { ... } while (0) so the macro behaves as a single statement:
// `if (c) SYSASSERT(rv, "x"); else ...` compiles and binds as intended.
// errno is copied before the exception is constructed because the order in
// which constructor arguments are evaluated is unspecified, and evaluating
// the message argument could otherwise clobber errno.
#define SYSASSERT(rv, ...)                                          \
  do {                                                              \
    if ((rv) < 0) {                                                 \
      const int sysassert_errno_ = errno;                           \
      throw std::system_error(                                      \
          sysassert_errno_, std::system_category(), ##__VA_ARGS__); \
    }                                                               \
  } while (0)
16 | |
17 | namespace torch { |
18 | namespace multiprocessing { |
19 | |
20 | namespace { |
21 | |
22 | PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) { |
23 | auto multiprocessing_module = |
24 | THPObjectPtr(PyImport_ImportModule("torch.multiprocessing" )); |
25 | if (!multiprocessing_module) { |
26 | throw python_error(); |
27 | } |
28 | |
29 | auto module = py::handle(multiprocessing_module).cast<py::module>(); |
30 | |
31 | module.def("_prctl_pr_set_pdeathsig" , [](int signal) { |
32 | #if defined(__linux__) |
33 | auto rv = prctl(PR_SET_PDEATHSIG, signal); |
34 | SYSASSERT(rv, "prctl" ); |
35 | #endif |
36 | }); |
37 | |
38 | Py_RETURN_TRUE; |
39 | } |
40 | |
41 | } // namespace |
42 | |
43 | // multiprocessing methods on torch._C |
44 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
45 | static PyMethodDef methods[] = { |
46 | { |
47 | "_multiprocessing_init" , |
48 | multiprocessing_init, |
49 | METH_NOARGS, |
50 | nullptr, |
51 | }, |
52 | {nullptr, nullptr, 0, nullptr}, |
53 | }; |
54 | |
55 | PyMethodDef* python_functions() { |
56 | return methods; |
57 | } |
58 | |
59 | } // namespace multiprocessing |
60 | } // namespace torch |
61 | |