1#pragma once
#include <torch/csrc/autograd/function_hook.h>
#include <functional>
#include <memory>
#include <vector>
5
6namespace torch {
7namespace autograd {
8
9using hooks_list =
10 std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
11
12struct CppFunctionTensorPreHook : public FunctionPreHook {
13 CppFunctionTensorPreHook(
14 const std::shared_ptr<hooks_list>& hooks,
15 int value_idx);
16 variable_list operator()(const variable_list& values) override;
17
18 std::shared_ptr<hooks_list> hooks_;
19 int value_idx_;
20};
21
22struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
23 CppFunctionSingleTensorPreHook(
24 std::function<at::TensorBase(const at::TensorBase&)> hook,
25 int value_idx);
26 variable_list operator()(const variable_list& values) override;
27
28 std::function<at::TensorBase(const at::TensorBase&)> hook_;
29 int value_idx_;
30};
31
32} // namespace autograd
33} // namespace torch
34