1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/TensorBody.h> |
4 | #include <c10/util/Exception.h> |
5 | |
6 | namespace at { |
7 | class TORCH_API OptionalTensorRef { |
8 | public: |
9 | OptionalTensorRef() = default; |
10 | |
11 | ~OptionalTensorRef() { |
12 | ref_.unsafeReleaseTensorImpl(); |
13 | } |
14 | |
15 | OptionalTensorRef(const TensorBase& src) |
16 | : ref_(Tensor::unsafe_borrow_t{}, src) { |
17 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); |
18 | } |
19 | |
20 | OptionalTensorRef(const OptionalTensorRef& rhs) |
21 | : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} |
22 | |
23 | OptionalTensorRef& operator=(OptionalTensorRef rhs) { |
24 | std::swap(ref_, rhs.ref_); |
25 | return *this; |
26 | } |
27 | |
28 | bool has_value() const { |
29 | return ref_.defined(); |
30 | } |
31 | |
32 | const Tensor& getTensorRef() const & { |
33 | return ref_; |
34 | } |
35 | |
36 | const Tensor& operator*() const & { |
37 | return ref_; |
38 | } |
39 | |
40 | const Tensor* operator->() const & { |
41 | return &ref_; |
42 | } |
43 | |
44 | operator bool() const { |
45 | return ref_.defined(); |
46 | } |
47 | |
48 | private: |
49 | Tensor ref_; |
50 | }; |
51 | |
52 | template <typename T> |
53 | auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> { |
54 | // Return the grad argument in case of a hook with void return type to have an |
55 | // std::function with Tensor return type |
56 | static_assert(std::is_same<decltype(hook(Tensor())), void>::value, |
57 | "Expected hook to return void"); |
58 | return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) { |
59 | OptionalTensorRef grad(grad_base); |
60 | fn(*grad); |
61 | return Tensor(); |
62 | }); |
63 | } |
64 | |
65 | template <typename T> |
66 | auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> { |
67 | return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) { |
68 | OptionalTensorRef grad(grad_base); |
69 | Tensor ret = fn(*grad); |
70 | return TensorBase(std::move(ret)); |
71 | }); |
72 | } |
73 | |
74 | } // namespace at |
75 |