1#pragma once
2
3#include <ATen/ATen.h>
4#include <pybind11/pybind11.h>
5#include <torch/csrc/Export.h>
6#include <torch/csrc/autograd/python_variable.h>
7#include <torch/csrc/autograd/saved_variable_hooks.h>
8#include <torch/csrc/python_headers.h>
9#include <torch/csrc/utils/pybind.h>
10
11namespace py = pybind11;
12
13namespace torch {
14namespace autograd {
15
16struct PySavedVariableHooks : public SavedVariableHooks {
17 PySavedVariableHooks(py::function& pack_hook, py::function& unpack_hook);
18 void call_pack_hook(const at::Tensor& tensor) override;
19 at::Tensor call_unpack_hook() override;
20 ~PySavedVariableHooks() override;
21
22 private:
23 PyObject* pack_hook_;
24 PyObject* unpack_hook_;
25 PyObject* data_ = nullptr;
26};
27
28struct PyDefaultSavedVariableHooks {
29 static void push_hooks(py::function& pack_hook, py::function& unpack_hook);
30 static void pop_hooks();
31 static std::unique_ptr<SavedVariableHooks> get_hooks();
32};
33
34} // namespace autograd
35} // namespace torch
36