1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <pybind11/pybind11.h> |
5 | #include <torch/csrc/Export.h> |
6 | #include <torch/csrc/autograd/python_variable.h> |
7 | #include <torch/csrc/autograd/saved_variable_hooks.h> |
8 | #include <torch/csrc/python_headers.h> |
9 | #include <torch/csrc/utils/pybind.h> |
10 | |
11 | namespace py = pybind11; |
12 | |
13 | namespace torch { |
14 | namespace autograd { |
15 | |
16 | struct PySavedVariableHooks : public SavedVariableHooks { |
17 | PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook); |
18 | void call_pack_hook(const at::Tensor& tensor) override; |
19 | at::Tensor call_unpack_hook() override; |
20 | ~PySavedVariableHooks() override; |
21 | |
22 | private: |
23 | PyObject* pack_hook_; |
24 | PyObject* unpack_hook_; |
25 | PyObject* data_ = nullptr; |
26 | }; |
27 | |
28 | struct PyDefaultSavedVariableHooks { |
29 | static void push_hooks(py::function& pack_hook, py::function& unpack_hook); |
30 | static void pop_hooks(); |
31 | static std::unique_ptr<SavedVariableHooks> get_hooks(); |
32 | }; |
33 | |
34 | } // namespace autograd |
35 | } // namespace torch |
36 |