1 | #include <torch/csrc/python_headers.h> |
2 | |
3 | #include <pybind11/chrono.h> |
4 | |
5 | #include <torch/csrc/jit/python/pybind_utils.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | |
8 | #include <ATen/cuda/CUDAGraph.h> |
9 | |
10 | // Cargo culted partially from csrc/distributed/c10d/init.cpp |
11 | // and partially from csrc/cuda/Stream.cpp. |
12 | // THCPStream_init is also declared at global scope. |
13 | |
14 | // Because THCPGraph_init is forward declared in the only consumer |
15 | // (csrc/Module.cpp) I don't think we need a Graph.h. |
16 | |
17 | template <typename T> |
18 | using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; |
19 | |
20 | void THCPGraph_init(PyObject* module) { |
21 | // Pybind11 patch notes say "py::module_" is more up-to-date syntax, |
22 | // but CI linter and some builds prefer "module". |
23 | auto torch_C_m = py::handle(module).cast<py::module>(); |
24 | |
25 | torch_C_m.def("_graph_pool_handle" , &::at::cuda::graph_pool_handle); |
26 | |
27 | shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph" ) |
28 | .def(py::init<>()) |
29 | // I'm not sure this is the correct order of all the arguments. Pybind11 |
30 | // docs aren't clear. But it works. |
31 | .def( |
32 | "capture_begin" , |
33 | torch::wrap_pybind_function_no_gil( |
34 | &at::cuda::CUDAGraph::capture_begin), |
35 | py::arg("pool" ) = c10::cuda::MempoolId_t{0, 0}) |
36 | .def( |
37 | "capture_end" , |
38 | torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end)) |
39 | .def( |
40 | "replay" , |
41 | torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay)) |
42 | .def( |
43 | "reset" , |
44 | torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset)) |
45 | .def( |
46 | "pool" , |
47 | torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool)) |
48 | .def( |
49 | "debug_dump" , |
50 | torch::wrap_pybind_function_no_gil( |
51 | &::at::cuda::CUDAGraph::debug_dump)) |
52 | .def( |
53 | "enable_debug_mode" , |
54 | torch::wrap_pybind_function_no_gil( |
55 | &::at::cuda::CUDAGraph::enable_debug_mode)) |
56 | .def( |
57 | "debug_dump" , |
58 | torch::wrap_pybind_function_no_gil( |
59 | &::at::cuda::CUDAGraph::debug_dump), |
60 | py::arg("debug_path" )); |
61 | } |
62 | |