1#pragma once
2
3#include <ATen/core/TensorBody.h>
4#include <c10/util/Exception.h>
5
6namespace at {
7class TORCH_API OptionalTensorRef {
8 public:
9 OptionalTensorRef() = default;
10
11 ~OptionalTensorRef() {
12 ref_.unsafeReleaseTensorImpl();
13 }
14
15 OptionalTensorRef(const TensorBase& src)
16 : ref_(Tensor::unsafe_borrow_t{}, src) {
17 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
18 }
19
20 OptionalTensorRef(const OptionalTensorRef& rhs)
21 : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
22
23 OptionalTensorRef& operator=(OptionalTensorRef rhs) {
24 std::swap(ref_, rhs.ref_);
25 return *this;
26 }
27
28 bool has_value() const {
29 return ref_.defined();
30 }
31
32 const Tensor& getTensorRef() const & {
33 return ref_;
34 }
35
36 const Tensor& operator*() const & {
37 return ref_;
38 }
39
40 const Tensor* operator->() const & {
41 return &ref_;
42 }
43
44 operator bool() const {
45 return ref_.defined();
46 }
47
48 private:
49 Tensor ref_;
50};
51
52template <typename T>
53auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
54 // Return the grad argument in case of a hook with void return type to have an
55 // std::function with Tensor return type
56 static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
57 "Expected hook to return void");
58 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
59 OptionalTensorRef grad(grad_base);
60 fn(*grad);
61 return Tensor();
62 });
63}
64
65template <typename T>
66auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
67 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
68 OptionalTensorRef grad(grad_base);
69 Tensor ret = fn(*grad);
70 return TensorBase(std::move(ret));
71 });
72}
73
74} // namespace at
75