1#include <c10/util/irange.h>
2#include <torch/csrc/autograd/cpp_hook.h>
3#include <torch/csrc/autograd/custom_function.h>
4#include <torch/csrc/autograd/variable.h>
5
6#include <utility>
7
8namespace {
9using torch::autograd::Variable;
10void check_single_result(
11 const at::TensorBase& value,
12 const at::TensorBase& result,
13 std::string hook_name) {
14 if (!value.defined()) {
15 throw std::runtime_error(
16 "can't replace a empty gradient with a non-empty value");
17 }
18 torch::autograd::check_variable_result(value, result, std::move(hook_name));
19}
20} // namespace
21
22namespace torch {
23namespace autograd {
24
25// NOLINTNEXTLINE(modernize-pass-by-value)
26CppFunctionTensorPreHook::CppFunctionTensorPreHook(
27 const std::shared_ptr<hooks_list>& hooks,
28 int value_idx)
29 : hooks_(hooks), value_idx_(value_idx) {}
30
31variable_list CppFunctionTensorPreHook::operator()(
32 const variable_list& values) {
33 auto value = values[value_idx_];
34 for (const auto i : c10::irange(hooks_->size())) {
35 auto& hook = (*hooks_)[i];
36 if (!hook) {
37 // hook was removed
38 continue;
39 }
40 auto res = hook(value);
41 if (!res.defined()) {
42 // Don't change gradient
43 continue;
44 }
45 check_single_result(value, res, c10::to_string(i));
46 value = std::move(res);
47 }
48 variable_list results(values);
49 results[value_idx_] = value;
50 return results;
51}
52
53CppFunctionSingleTensorPreHook::CppFunctionSingleTensorPreHook(
54 std::function<at::TensorBase(const at::TensorBase&)> hook,
55 int value_idx)
56 : hook_(std::move(hook)), value_idx_(value_idx) {}
57
58variable_list CppFunctionSingleTensorPreHook::operator()(
59 const variable_list& values) {
60 const auto& value = values[value_idx_];
61 auto res = hook_(value);
62 TORCH_INTERNAL_ASSERT(
63 !res.defined(),
64 "CppFunctionSingleTensorPreHook currently only supports hooks that don't return");
65 variable_list results(values);
66 return results;
67}
68
69} // namespace autograd
70} // namespace torch
71