1 | // This file defines OptionalArrayRef<T>, a class that has almost the same |
2 | // exact functionality as c10::optional<ArrayRef<T>>, except that its |
3 | // converting constructor fixes a dangling pointer issue. |
4 | // |
5 | // The implicit converting constructor of both c10::optional<ArrayRef<T>> and |
6 | // std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store |
7 | // a dangling pointer. OptionalArrayRef<T> prevents this by wrapping |
8 | // a c10::optional<ArrayRef<T>> and fixing the constructor implementation. |
9 | // |
10 | // See https://github.com/pytorch/pytorch/issues/63645 for more on this. |
11 | |
12 | #pragma once |
13 | |
14 | #include <c10/util/ArrayRef.h> |
15 | #include <c10/util/Optional.h> |
16 | |
17 | namespace c10 { |
18 | |
19 | template <typename T> |
20 | class OptionalArrayRef final { |
21 | public: |
22 | // Constructors |
23 | |
24 | constexpr OptionalArrayRef() noexcept = default; |
25 | |
26 | constexpr OptionalArrayRef(nullopt_t) noexcept {} |
27 | |
28 | OptionalArrayRef(const OptionalArrayRef& other) = default; |
29 | |
30 | OptionalArrayRef(OptionalArrayRef&& other) = default; |
31 | |
32 | constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept |
33 | : wrapped_opt_array_ref(other) {} |
34 | |
35 | constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept |
36 | : wrapped_opt_array_ref(other) {} |
37 | |
38 | constexpr OptionalArrayRef(const T& value) noexcept |
39 | : wrapped_opt_array_ref(value) {} |
40 | |
41 | template < |
42 | typename U = ArrayRef<T>, |
43 | std::enable_if_t< |
44 | !std::is_same<std::decay_t<U>, OptionalArrayRef>::value && |
45 | !std::is_same<std::decay_t<U>, in_place_t>::value && |
46 | std::is_constructible<ArrayRef<T>, U&&>::value && |
47 | std::is_convertible<U&&, ArrayRef<T>>::value && |
48 | !std::is_convertible<U&&, T>::value, |
49 | bool> = false> |
50 | constexpr OptionalArrayRef(U&& value) noexcept( |
51 | std::is_nothrow_constructible<ArrayRef<T>, U&&>::value) |
52 | : wrapped_opt_array_ref(value) {} |
53 | |
54 | template < |
55 | typename U = ArrayRef<T>, |
56 | std::enable_if_t< |
57 | !std::is_same<std::decay_t<U>, OptionalArrayRef>::value && |
58 | !std::is_same<std::decay_t<U>, in_place_t>::value && |
59 | std::is_constructible<ArrayRef<T>, U&&>::value && |
60 | !std::is_convertible<U&&, ArrayRef<T>>::value, |
61 | bool> = false> |
62 | constexpr explicit OptionalArrayRef(U&& value) noexcept( |
63 | std::is_nothrow_constructible<ArrayRef<T>, U&&>::value) |
64 | : wrapped_opt_array_ref(value) {} |
65 | |
66 | template <typename... Args> |
67 | constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept |
68 | : wrapped_opt_array_ref(ip, args...) {} |
69 | |
70 | template <typename U, typename... Args> |
71 | constexpr explicit OptionalArrayRef( |
72 | in_place_t ip, |
73 | std::initializer_list<U> il, |
74 | Args&&... args) |
75 | : wrapped_opt_array_ref(ip, il, args...) {} |
76 | |
77 | constexpr OptionalArrayRef(const std::initializer_list<T>& Vec) |
78 | : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {} |
79 | |
80 | // Destructor |
81 | |
82 | ~OptionalArrayRef() = default; |
83 | |
84 | // Assignment |
85 | |
86 | constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { |
87 | wrapped_opt_array_ref = c10::nullopt; |
88 | return *this; |
89 | } |
90 | |
91 | OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; |
92 | |
93 | OptionalArrayRef& operator=(OptionalArrayRef&& other) = default; |
94 | |
95 | constexpr OptionalArrayRef& operator=( |
96 | const optional<ArrayRef<T>>& other) noexcept { |
97 | wrapped_opt_array_ref = other; |
98 | return *this; |
99 | } |
100 | |
101 | constexpr OptionalArrayRef& operator=( |
102 | optional<ArrayRef<T>>&& other) noexcept { |
103 | wrapped_opt_array_ref = other; |
104 | return *this; |
105 | } |
106 | |
107 | template <typename U = ArrayRef<T>> |
108 | constexpr std::enable_if_t< |
109 | !std::is_same<std::decay_t<U>, OptionalArrayRef>::value && |
110 | std::is_constructible<ArrayRef<T>, U&&>::value && |
111 | std::is_assignable<ArrayRef<T>&, U&&>::value, |
112 | OptionalArrayRef&> |
113 | operator=(U&& value) noexcept( |
114 | std::is_nothrow_constructible<ArrayRef<T>, U&&>::value&& |
115 | std::is_nothrow_assignable<ArrayRef<T>&, U&&>::value) { |
116 | wrapped_opt_array_ref = value; |
117 | return *this; |
118 | } |
119 | |
120 | // Observers |
121 | |
122 | constexpr ArrayRef<T>* operator->() noexcept { |
123 | return &wrapped_opt_array_ref.value(); |
124 | } |
125 | |
126 | constexpr const ArrayRef<T>* operator->() const noexcept { |
127 | return &wrapped_opt_array_ref.value(); |
128 | } |
129 | |
130 | constexpr ArrayRef<T>& operator*() & noexcept { |
131 | return wrapped_opt_array_ref.value(); |
132 | } |
133 | |
134 | constexpr const ArrayRef<T>& operator*() const& noexcept { |
135 | return wrapped_opt_array_ref.value(); |
136 | } |
137 | |
138 | constexpr ArrayRef<T>&& operator*() && noexcept { |
139 | return std::move(wrapped_opt_array_ref.value()); |
140 | } |
141 | |
142 | constexpr const ArrayRef<T>&& operator*() const&& noexcept { |
143 | return std::move(wrapped_opt_array_ref.value()); |
144 | } |
145 | |
146 | constexpr explicit operator bool() const noexcept { |
147 | return wrapped_opt_array_ref.has_value(); |
148 | } |
149 | |
150 | constexpr bool has_value() const noexcept { |
151 | return wrapped_opt_array_ref.has_value(); |
152 | } |
153 | |
154 | constexpr ArrayRef<T>& value() & { |
155 | return wrapped_opt_array_ref.value(); |
156 | } |
157 | |
158 | constexpr const ArrayRef<T>& value() const& { |
159 | return wrapped_opt_array_ref.value(); |
160 | } |
161 | |
162 | constexpr ArrayRef<T>&& value() && { |
163 | return std::move(wrapped_opt_array_ref.value()); |
164 | } |
165 | |
166 | constexpr const ArrayRef<T>&& value() const&& { |
167 | return std::move(wrapped_opt_array_ref.value()); |
168 | } |
169 | |
170 | template <typename U> |
171 | constexpr std:: |
172 | enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>> |
173 | value_or(U&& default_value) const& { |
174 | return wrapped_opt_array_ref.value_or(default_value); |
175 | } |
176 | |
177 | template <typename U> |
178 | constexpr std:: |
179 | enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>> |
180 | value_or(U&& default_value) && { |
181 | return wrapped_opt_array_ref.value_or(default_value); |
182 | } |
183 | |
184 | // Modifiers |
185 | |
186 | constexpr void swap(OptionalArrayRef& other) noexcept { |
187 | std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); |
188 | } |
189 | |
190 | constexpr void reset() noexcept { |
191 | wrapped_opt_array_ref.reset(); |
192 | } |
193 | |
194 | template <typename... Args> |
195 | constexpr std::enable_if_t< |
196 | std::is_constructible<ArrayRef<T>, Args&&...>::value, |
197 | ArrayRef<T>&> |
198 | emplace(Args&&... args) noexcept( |
199 | std::is_nothrow_constructible<ArrayRef<T>, Args&&...>::value) { |
200 | return wrapped_opt_array_ref.emplace(args...); |
201 | } |
202 | |
203 | template <typename U, typename... Args> |
204 | constexpr ArrayRef<T>& emplace( |
205 | std::initializer_list<U> il, |
206 | Args&&... args) noexcept { |
207 | return wrapped_opt_array_ref.emplace(il, args...); |
208 | } |
209 | |
210 | private: |
211 | optional<ArrayRef<T>> wrapped_opt_array_ref; |
212 | }; |
213 | |
214 | using OptionalIntArrayRef = OptionalArrayRef<int64_t>; |
215 | |
216 | inline bool operator==( |
217 | const OptionalIntArrayRef& a1, |
218 | const IntArrayRef& other) { |
219 | if (!a1.has_value()) { |
220 | return false; |
221 | } |
222 | return a1.value() == other; |
223 | } |
224 | |
225 | inline bool operator==( |
226 | const c10::IntArrayRef& a1, |
227 | const c10::OptionalIntArrayRef& a2) { |
228 | return a2 == a1; |
229 | } |
230 | |
231 | } // namespace c10 |
232 | |