1 | #pragma once |
2 | |
3 | #include <memory> |
4 | #include <type_traits> |
5 | |
6 | #include <c10/util/Exception.h> |
7 | #include <c10/util/MaybeOwned.h> |
8 | |
9 | namespace c10 { |
10 | |
11 | // Compatibility wrapper around a raw pointer so that existing code |
12 | // written to deal with a shared_ptr can keep working. |
13 | template <typename T> |
14 | class SingletonTypePtr { |
15 | public: |
16 | /* implicit */ SingletonTypePtr(T* p) : repr_(p) {} |
17 | |
18 | // We need this to satisfy Pybind11, but it shouldn't be hit. |
19 | explicit SingletonTypePtr(std::shared_ptr<T>) { TORCH_CHECK(false); } |
20 | |
21 | using element_type = typename std::shared_ptr<T>::element_type; |
22 | |
23 | template <typename U = T, std::enable_if_t<!std::is_same<std::remove_const_t<U>, void>::value, bool> = true> |
24 | T& operator*() const { |
25 | return *repr_; |
26 | } |
27 | |
28 | T* get() const { |
29 | return repr_; |
30 | } |
31 | |
32 | T* operator->() const { |
33 | return repr_; |
34 | } |
35 | |
36 | operator bool() const { |
37 | return repr_ != nullptr; |
38 | } |
39 | |
40 | private: |
41 | T* repr_{nullptr}; |
42 | }; |
43 | |
44 | template <typename T, typename U> |
45 | bool operator==(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) { |
46 | return (void*)lhs.get() == (void*)rhs.get(); |
47 | } |
48 | |
49 | template <typename T, typename U> |
50 | bool operator!=(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) { |
51 | return !(lhs == rhs); |
52 | } |
53 | |
54 | } // namespace c10 |
55 | |