1#pragma once
2
3#include <ATen/Tensor.h>
4#include <torch/csrc/Export.h>
5#include <vector>
6
// Abstract hook interfaces (pre-hooks and post-hooks) that autograd invokes on gradients.
8
9namespace torch {
10namespace autograd {
11
12using Variable = at::Tensor;
13using variable_list = std::vector<Variable>;
14
15struct TORCH_API FunctionPreHook {
16 virtual ~FunctionPreHook() = default;
17 virtual variable_list operator()(const variable_list& grads) = 0;
18};
19
20struct TORCH_API FunctionPostHook {
21 virtual ~FunctionPostHook() = default;
22 virtual variable_list operator()(
23 const variable_list& outputs /* grad_inputs */,
24 const variable_list& inputs /* grad_outputs */) = 0;
25};
26
27} // namespace autograd
28} // namespace torch
29