1#pragma once
2
3#include <Python.h>
4
5#include <torch/csrc/profiler/collection.h>
6#include <torch/csrc/profiler/python/pybind.h>
7
8namespace pybind11 {
9namespace detail {
10using torch::profiler::impl::TensorID;
11
12#define STRONG_POINTER_TYPE_CASTER(T) \
13 template <> \
14 struct type_caster<T> : public strong_pointer_type_caster<T> {};
15
16STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::StorageImplData);
17STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::AllocationID);
18STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::TensorImplAddress);
19STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyModuleSelf);
20STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyModuleCls);
21STRONG_POINTER_TYPE_CASTER(torch::profiler::impl::PyOptimizerSelf);
22#undef STRONG_POINTER_TYPE_CASTER
23
24template <>
25struct type_caster<TensorID> : public strong_uint_type_caster<TensorID> {};
26} // namespace detail
27} // namespace pybind11
28
29namespace torch {
30namespace profiler {
31
32void initPythonBindings(PyObject* module);
33
34} // namespace profiler
35} // namespace torch
36