1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/Tensor.h> |
4 | #include <torch/csrc/Export.h> |
5 | #include <vector> |
6 | |
7 | // Hooks that are called on gradients (FunctionPreHook and FunctionPostHook below). |
8 | |
9 | namespace torch { |
10 | namespace autograd { |
11 | |
12 | using Variable = at::Tensor; |
13 | using variable_list = std::vector<Variable>; |
14 | |
15 | struct TORCH_API FunctionPreHook { |
16 | virtual ~FunctionPreHook() = default; |
17 | virtual variable_list operator()(const variable_list& grads) = 0; |
18 | }; |
19 | |
20 | struct TORCH_API FunctionPostHook { |
21 | virtual ~FunctionPostHook() = default; |
22 | virtual variable_list operator()( |
23 | const variable_list& outputs /* grad_inputs */, |
24 | const variable_list& inputs /* grad_outputs */) = 0; |
25 | }; |
26 | |
27 | } // namespace autograd |
28 | } // namespace torch |
29 |