1 | #pragma once |
---|---|
2 | |
3 | #include <torch/csrc/autograd/function_hook.h> |
4 | #include <torch/csrc/python_headers.h> |
5 | #include <torch/csrc/utils/object_ptr.h> |
6 | |
7 | namespace torch { |
8 | namespace autograd { |
9 | |
10 | struct PyFunctionTensorPreHook : public FunctionPreHook { |
11 | PyFunctionTensorPreHook(PyObject* dict, int value_idx); |
12 | ~PyFunctionTensorPreHook() override; |
13 | variable_list operator()(const variable_list& values) override; |
14 | PyObject* dict; |
15 | int value_idx; |
16 | }; |
17 | |
18 | struct PyFunctionPreHook : public FunctionPreHook { |
19 | PyFunctionPreHook(PyObject* dict); |
20 | ~PyFunctionPreHook() override; |
21 | variable_list operator()(const variable_list& values) override; |
22 | PyObject* dict; |
23 | }; |
24 | |
25 | struct PyFunctionPostHook : public FunctionPostHook { |
26 | PyFunctionPostHook(PyObject* dict); |
27 | ~PyFunctionPostHook() override; |
28 | variable_list operator()( |
29 | const variable_list& outputs, |
30 | const variable_list& inputs) override; |
31 | PyObject* dict; |
32 | }; |
33 | |
34 | } // namespace autograd |
35 | } // namespace torch |
36 |