#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>

#include <cerrno>
#include <stdexcept>
#include <system_error>

#if defined(__linux__)
#include <sys/prctl.h>
#endif
11
// SYSASSERT(rv, ...): if `rv` is negative (the usual failure convention of
// POSIX calls such as prctl), throw std::system_error built from the current
// errno. Any extra arguments are forwarded to the std::system_error
// constructor after the error category, typically a "what" string naming the
// failing call, e.g. SYSASSERT(rv, "prctl").
//
// Why the do { ... } while (0): it makes the expansion a single statement, so
// `if (c) SYSASSERT(rv, "x"); else ...` parses as intended. A bare `if {}`
// expansion makes the `;` terminate the outer if and breaks the `else`.
//
// Why errno is snapshotted first: argument evaluation order is unspecified,
// and building the message argument (e.g. a std::string) may allocate and
// clobber errno before it is read.
#define SYSASSERT(rv, ...)                                          \
  do {                                                              \
    if ((rv) < 0) {                                                 \
      const int sysassert_saved_errno_ = errno;                     \
      throw std::system_error(                                      \
          sysassert_saved_errno_, std::system_category(), ##__VA_ARGS__); \
    }                                                               \
  } while (0)
16
17namespace torch {
18namespace multiprocessing {
19
20namespace {
21
22PyObject* multiprocessing_init(PyObject* _unused, PyObject* noargs) {
23 auto multiprocessing_module =
24 THPObjectPtr(PyImport_ImportModule("torch.multiprocessing"));
25 if (!multiprocessing_module) {
26 throw python_error();
27 }
28
29 auto module = py::handle(multiprocessing_module).cast<py::module>();
30
31 module.def("_prctl_pr_set_pdeathsig", [](int signal) {
32#if defined(__linux__)
33 auto rv = prctl(PR_SET_PDEATHSIG, signal);
34 SYSASSERT(rv, "prctl");
35#endif
36 });
37
38 Py_RETURN_TRUE;
39}
40
41} // namespace
42
43// multiprocessing methods on torch._C
44// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
45static PyMethodDef methods[] = {
46 {
47 "_multiprocessing_init",
48 multiprocessing_init,
49 METH_NOARGS,
50 nullptr,
51 },
52 {nullptr, nullptr, 0, nullptr},
53};
54
55PyMethodDef* python_functions() {
56 return methods;
57}
58
59} // namespace multiprocessing
60} // namespace torch
61