1#pragma once
2
3#include <torch/csrc/autograd/function_hook.h>
4#include <torch/csrc/python_headers.h>
5#include <torch/csrc/utils/object_ptr.h>
6
7namespace torch {
8namespace autograd {
9
10struct PyFunctionTensorPreHook : public FunctionPreHook {
11 PyFunctionTensorPreHook(PyObject* dict, int value_idx);
12 ~PyFunctionTensorPreHook() override;
13 variable_list operator()(const variable_list& values) override;
14 PyObject* dict;
15 int value_idx;
16};
17
18struct PyFunctionPreHook : public FunctionPreHook {
19 PyFunctionPreHook(PyObject* dict);
20 ~PyFunctionPreHook() override;
21 variable_list operator()(const variable_list& values) override;
22 PyObject* dict;
23};
24
25struct PyFunctionPostHook : public FunctionPostHook {
26 PyFunctionPostHook(PyObject* dict);
27 ~PyFunctionPostHook() override;
28 variable_list operator()(
29 const variable_list& outputs,
30 const variable_list& inputs) override;
31 PyObject* dict;
32};
33
34} // namespace autograd
35} // namespace torch
36