1 | #pragma once |
---|---|
2 | #include <torch/csrc/autograd/function_hook.h> |
3 | #include <functional> |
4 | #include <memory> |
5 | |
6 | namespace torch { |
7 | namespace autograd { |
8 | |
9 | using hooks_list = |
10 | std::vector<std::function<at::TensorBase(const at::TensorBase&)>>; |
11 | |
12 | struct CppFunctionTensorPreHook : public FunctionPreHook { |
13 | CppFunctionTensorPreHook( |
14 | const std::shared_ptr<hooks_list>& hooks, |
15 | int value_idx); |
16 | variable_list operator()(const variable_list& values) override; |
17 | |
18 | std::shared_ptr<hooks_list> hooks_; |
19 | int value_idx_; |
20 | }; |
21 | |
22 | struct CppFunctionSingleTensorPreHook : public FunctionPreHook { |
23 | CppFunctionSingleTensorPreHook( |
24 | std::function<at::TensorBase(const at::TensorBase&)> hook, |
25 | int value_idx); |
26 | variable_list operator()(const variable_list& values) override; |
27 | |
28 | std::function<at::TensorBase(const at::TensorBase&)> hook_; |
29 | int value_idx_; |
30 | }; |
31 | |
32 | } // namespace autograd |
33 | } // namespace torch |
34 |