1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/autograd/cpp_hook.h> |
3 | #include <torch/csrc/autograd/custom_function.h> |
4 | #include <torch/csrc/autograd/variable.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace { |
9 | using torch::autograd::Variable; |
10 | void check_single_result( |
11 | const at::TensorBase& value, |
12 | const at::TensorBase& result, |
13 | std::string hook_name) { |
14 | if (!value.defined()) { |
15 | throw std::runtime_error( |
16 | "can't replace a empty gradient with a non-empty value" ); |
17 | } |
18 | torch::autograd::check_variable_result(value, result, std::move(hook_name)); |
19 | } |
20 | } // namespace |
21 | |
22 | namespace torch { |
23 | namespace autograd { |
24 | |
25 | // NOLINTNEXTLINE(modernize-pass-by-value) |
26 | CppFunctionTensorPreHook::CppFunctionTensorPreHook( |
27 | const std::shared_ptr<hooks_list>& hooks, |
28 | int value_idx) |
29 | : hooks_(hooks), value_idx_(value_idx) {} |
30 | |
31 | variable_list CppFunctionTensorPreHook::operator()( |
32 | const variable_list& values) { |
33 | auto value = values[value_idx_]; |
34 | for (const auto i : c10::irange(hooks_->size())) { |
35 | auto& hook = (*hooks_)[i]; |
36 | if (!hook) { |
37 | // hook was removed |
38 | continue; |
39 | } |
40 | auto res = hook(value); |
41 | if (!res.defined()) { |
42 | // Don't change gradient |
43 | continue; |
44 | } |
45 | check_single_result(value, res, c10::to_string(i)); |
46 | value = std::move(res); |
47 | } |
48 | variable_list results(values); |
49 | results[value_idx_] = value; |
50 | return results; |
51 | } |
52 | |
53 | CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook( |
54 | std::function<at::TensorBase(const at::TensorBase&)> hook, |
55 | int value_idx) |
56 | : hook_(std::move(hook)), value_idx_(value_idx) {} |
57 | |
58 | variable_list CppFunctionSingleTensorPreHook::operator()( |
59 | const variable_list& values) { |
60 | const auto& value = values[value_idx_]; |
61 | auto res = hook_(value); |
62 | TORCH_INTERNAL_ASSERT( |
63 | !res.defined(), |
64 | "CppFunctionSingleTensorPreHook currently only supports hooks that don't return" ); |
65 | variable_list results(values); |
66 | return results; |
67 | } |
68 | |
69 | } // namespace autograd |
70 | } // namespace torch |
71 | |