1#ifdef _WIN32
2#include <wchar.h> // _wgetenv for nvtx
3#endif
4#include <nvToolsExt.h>
5#include <torch/csrc/utils/pybind.h>
6
7namespace torch {
8namespace cuda {
9namespace shared {
10
11void initNvtxBindings(PyObject* module) {
12 auto m = py::handle(module).cast<py::module>();
13
14 auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
15 nvtx.def("rangePushA", nvtxRangePushA);
16 nvtx.def("rangePop", nvtxRangePop);
17 nvtx.def("rangeStartA", nvtxRangeStartA);
18 nvtx.def("rangeEnd", nvtxRangeEnd);
19 nvtx.def("markA", nvtxMarkA);
20}
21
22} // namespace shared
23} // namespace cuda
24} // namespace torch
25