1#pragma once
2
3#include <memory>
4#include <type_traits>
5
6#include <c10/util/Exception.h>
7#include <c10/util/MaybeOwned.h>
8
9namespace c10 {
10
11// Compatibility wrapper around a raw pointer so that existing code
12// written to deal with a shared_ptr can keep working.
13template <typename T>
14class SingletonTypePtr {
15 public:
16 /* implicit */ SingletonTypePtr(T* p) : repr_(p) {}
17
18 // We need this to satisfy Pybind11, but it shouldn't be hit.
19 explicit SingletonTypePtr(std::shared_ptr<T>) { TORCH_CHECK(false); }
20
21 using element_type = typename std::shared_ptr<T>::element_type;
22
23 template <typename U = T, std::enable_if_t<!std::is_same<std::remove_const_t<U>, void>::value, bool> = true>
24 T& operator*() const {
25 return *repr_;
26 }
27
28 T* get() const {
29 return repr_;
30 }
31
32 T* operator->() const {
33 return repr_;
34 }
35
36 operator bool() const {
37 return repr_ != nullptr;
38 }
39
40 private:
41 T* repr_{nullptr};
42};
43
44template <typename T, typename U>
45bool operator==(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) {
46 return (void*)lhs.get() == (void*)rhs.get();
47}
48
49template <typename T, typename U>
50bool operator!=(SingletonTypePtr<T> lhs, SingletonTypePtr<U> rhs) {
51 return !(lhs == rhs);
52}
53
54} // namespace c10
55