1 | #ifdef _WIN32 |
2 | #include <wchar.h> // _wgetenv for nvtx |
3 | #endif |
4 | #include <nvToolsExt.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | |
7 | namespace torch { |
8 | namespace cuda { |
9 | namespace shared { |
10 | |
11 | void initNvtxBindings(PyObject* module) { |
12 | auto m = py::handle(module).cast<py::module>(); |
13 | |
14 | auto nvtx = m.def_submodule("_nvtx" , "libNvToolsExt.so bindings" ); |
15 | nvtx.def("rangePushA" , nvtxRangePushA); |
16 | nvtx.def("rangePop" , nvtxRangePop); |
17 | nvtx.def("rangeStartA" , nvtxRangeStartA); |
18 | nvtx.def("rangeEnd" , nvtxRangeEnd); |
19 | nvtx.def("markA" , nvtxMarkA); |
20 | } |
21 | |
22 | } // namespace shared |
23 | } // namespace cuda |
24 | } // namespace torch |
25 | |