1#include <torch/csrc/python_headers.h>
2
3#include <pybind11/chrono.h>
4
5#include <torch/csrc/jit/python/pybind_utils.h>
6#include <torch/csrc/utils/pybind.h>
7
8#include <ATen/cuda/CUDAGraph.h>
9
// Structure adapted partly from csrc/distributed/c10d/init.cpp
// and partly from csrc/cuda/Stream.cpp.
// Like THCPStream_init, THCPGraph_init is declared at global scope.
13
// THCPGraph_init is forward-declared in its only consumer
// (csrc/Module.cpp), so a separate Graph.h header is not needed.
16
17template <typename T>
18using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
19
20void THCPGraph_init(PyObject* module) {
21 // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
22 // but CI linter and some builds prefer "module".
23 auto torch_C_m = py::handle(module).cast<py::module>();
24
25 torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
26
27 shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
28 .def(py::init<>())
29 // I'm not sure this is the correct order of all the arguments. Pybind11
30 // docs aren't clear. But it works.
31 .def(
32 "capture_begin",
33 torch::wrap_pybind_function_no_gil(
34 &at::cuda::CUDAGraph::capture_begin),
35 py::arg("pool") = c10::cuda::MempoolId_t{0, 0})
36 .def(
37 "capture_end",
38 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
39 .def(
40 "replay",
41 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
42 .def(
43 "reset",
44 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
45 .def(
46 "pool",
47 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
48 .def(
49 "debug_dump",
50 torch::wrap_pybind_function_no_gil(
51 &::at::cuda::CUDAGraph::debug_dump))
52 .def(
53 "enable_debug_mode",
54 torch::wrap_pybind_function_no_gil(
55 &::at::cuda::CUDAGraph::enable_debug_mode))
56 .def(
57 "debug_dump",
58 torch::wrap_pybind_function_no_gil(
59 &::at::cuda::CUDAGraph::debug_dump),
60 py::arg("debug_path"));
61}
62