1 | #pragma once |
---|---|
2 | |
3 | #include <Python.h> |
4 | |
5 | #include <torch/csrc/profiler/collection.h> |
6 | #include <torch/csrc/profiler/python/pybind.h> |
7 | |
8 | namespace pybind11 { |
9 | namespace detail { |
10 | using torch::profiler::impl::TensorID; |
11 | |
12 | #define STRONG_POINTER_TYPE_CASTER(T) \ |
13 | template <> \ |
14 | struct type_caster<T> : public strong_pointer_type_caster<T> {}; |
15 | |
16 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::StorageImplData); |
17 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::AllocationID); |
18 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::TensorImplAddress); |
19 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyModuleSelf); |
20 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyModuleCls); |
21 | STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyOptimizerSelf); |
22 | #undef STRONG_POINTER_TYPE_CASTER |
23 | |
24 | template <> |
25 | struct type_caster<TensorID> : public strong_uint_type_caster<TensorID> {}; |
26 | } // namespace detail |
27 | } // namespace pybind11 |
28 | |
29 | namespace torch { |
30 | namespace profiler { |
31 | |
32 | void initPythonBindings(PyObject* module); |
33 | |
34 | } // namespace profiler |
35 | } // namespace torch |
36 |