1// This file defines OptionalArrayRef<T>, a class that has almost the same
2// exact functionality as c10::optional<ArrayRef<T>>, except that its
3// converting constructor fixes a dangling pointer issue.
4//
5// The implicit converting constructor of both c10::optional<ArrayRef<T>> and
6// std::optional<ArrayRef<T>> can cause the underlying ArrayRef<T> to store
7// a dangling pointer. OptionalArrayRef<T> prevents this by wrapping
8// a c10::optional<ArrayRef<T>> and fixing the constructor implementation.
9//
10// See https://github.com/pytorch/pytorch/issues/63645 for more on this.
11
12#pragma once
13
14#include <c10/util/ArrayRef.h>
15#include <c10/util/Optional.h>
16
17namespace c10 {
18
19template <typename T>
20class OptionalArrayRef final {
21 public:
22 // Constructors
23
24 constexpr OptionalArrayRef() noexcept = default;
25
26 constexpr OptionalArrayRef(nullopt_t) noexcept {}
27
28 OptionalArrayRef(const OptionalArrayRef& other) = default;
29
30 OptionalArrayRef(OptionalArrayRef&& other) = default;
31
32 constexpr OptionalArrayRef(const optional<ArrayRef<T>>& other) noexcept
33 : wrapped_opt_array_ref(other) {}
34
35 constexpr OptionalArrayRef(optional<ArrayRef<T>>&& other) noexcept
36 : wrapped_opt_array_ref(other) {}
37
38 constexpr OptionalArrayRef(const T& value) noexcept
39 : wrapped_opt_array_ref(value) {}
40
41 template <
42 typename U = ArrayRef<T>,
43 std::enable_if_t<
44 !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
45 !std::is_same<std::decay_t<U>, in_place_t>::value &&
46 std::is_constructible<ArrayRef<T>, U&&>::value &&
47 std::is_convertible<U&&, ArrayRef<T>>::value &&
48 !std::is_convertible<U&&, T>::value,
49 bool> = false>
50 constexpr OptionalArrayRef(U&& value) noexcept(
51 std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
52 : wrapped_opt_array_ref(value) {}
53
54 template <
55 typename U = ArrayRef<T>,
56 std::enable_if_t<
57 !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
58 !std::is_same<std::decay_t<U>, in_place_t>::value &&
59 std::is_constructible<ArrayRef<T>, U&&>::value &&
60 !std::is_convertible<U&&, ArrayRef<T>>::value,
61 bool> = false>
62 constexpr explicit OptionalArrayRef(U&& value) noexcept(
63 std::is_nothrow_constructible<ArrayRef<T>, U&&>::value)
64 : wrapped_opt_array_ref(value) {}
65
66 template <typename... Args>
67 constexpr explicit OptionalArrayRef(in_place_t ip, Args&&... args) noexcept
68 : wrapped_opt_array_ref(ip, args...) {}
69
70 template <typename U, typename... Args>
71 constexpr explicit OptionalArrayRef(
72 in_place_t ip,
73 std::initializer_list<U> il,
74 Args&&... args)
75 : wrapped_opt_array_ref(ip, il, args...) {}
76
77 constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
78 : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
79
80 // Destructor
81
82 ~OptionalArrayRef() = default;
83
84 // Assignment
85
86 constexpr OptionalArrayRef& operator=(nullopt_t) noexcept {
87 wrapped_opt_array_ref = c10::nullopt;
88 return *this;
89 }
90
91 OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
92
93 OptionalArrayRef& operator=(OptionalArrayRef&& other) = default;
94
95 constexpr OptionalArrayRef& operator=(
96 const optional<ArrayRef<T>>& other) noexcept {
97 wrapped_opt_array_ref = other;
98 return *this;
99 }
100
101 constexpr OptionalArrayRef& operator=(
102 optional<ArrayRef<T>>&& other) noexcept {
103 wrapped_opt_array_ref = other;
104 return *this;
105 }
106
107 template <typename U = ArrayRef<T>>
108 constexpr std::enable_if_t<
109 !std::is_same<std::decay_t<U>, OptionalArrayRef>::value &&
110 std::is_constructible<ArrayRef<T>, U&&>::value &&
111 std::is_assignable<ArrayRef<T>&, U&&>::value,
112 OptionalArrayRef&>
113 operator=(U&& value) noexcept(
114 std::is_nothrow_constructible<ArrayRef<T>, U&&>::value&&
115 std::is_nothrow_assignable<ArrayRef<T>&, U&&>::value) {
116 wrapped_opt_array_ref = value;
117 return *this;
118 }
119
120 // Observers
121
122 constexpr ArrayRef<T>* operator->() noexcept {
123 return &wrapped_opt_array_ref.value();
124 }
125
126 constexpr const ArrayRef<T>* operator->() const noexcept {
127 return &wrapped_opt_array_ref.value();
128 }
129
130 constexpr ArrayRef<T>& operator*() & noexcept {
131 return wrapped_opt_array_ref.value();
132 }
133
134 constexpr const ArrayRef<T>& operator*() const& noexcept {
135 return wrapped_opt_array_ref.value();
136 }
137
138 constexpr ArrayRef<T>&& operator*() && noexcept {
139 return std::move(wrapped_opt_array_ref.value());
140 }
141
142 constexpr const ArrayRef<T>&& operator*() const&& noexcept {
143 return std::move(wrapped_opt_array_ref.value());
144 }
145
146 constexpr explicit operator bool() const noexcept {
147 return wrapped_opt_array_ref.has_value();
148 }
149
150 constexpr bool has_value() const noexcept {
151 return wrapped_opt_array_ref.has_value();
152 }
153
154 constexpr ArrayRef<T>& value() & {
155 return wrapped_opt_array_ref.value();
156 }
157
158 constexpr const ArrayRef<T>& value() const& {
159 return wrapped_opt_array_ref.value();
160 }
161
162 constexpr ArrayRef<T>&& value() && {
163 return std::move(wrapped_opt_array_ref.value());
164 }
165
166 constexpr const ArrayRef<T>&& value() const&& {
167 return std::move(wrapped_opt_array_ref.value());
168 }
169
170 template <typename U>
171 constexpr std::
172 enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
173 value_or(U&& default_value) const& {
174 return wrapped_opt_array_ref.value_or(default_value);
175 }
176
177 template <typename U>
178 constexpr std::
179 enable_if_t<std::is_convertible<U&&, ArrayRef<T>>::value, ArrayRef<T>>
180 value_or(U&& default_value) && {
181 return wrapped_opt_array_ref.value_or(default_value);
182 }
183
184 // Modifiers
185
186 constexpr void swap(OptionalArrayRef& other) noexcept {
187 std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
188 }
189
190 constexpr void reset() noexcept {
191 wrapped_opt_array_ref.reset();
192 }
193
194 template <typename... Args>
195 constexpr std::enable_if_t<
196 std::is_constructible<ArrayRef<T>, Args&&...>::value,
197 ArrayRef<T>&>
198 emplace(Args&&... args) noexcept(
199 std::is_nothrow_constructible<ArrayRef<T>, Args&&...>::value) {
200 return wrapped_opt_array_ref.emplace(args...);
201 }
202
203 template <typename U, typename... Args>
204 constexpr ArrayRef<T>& emplace(
205 std::initializer_list<U> il,
206 Args&&... args) noexcept {
207 return wrapped_opt_array_ref.emplace(il, args...);
208 }
209
210 private:
211 optional<ArrayRef<T>> wrapped_opt_array_ref;
212};
213
214using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
215
216inline bool operator==(
217 const OptionalIntArrayRef& a1,
218 const IntArrayRef& other) {
219 if (!a1.has_value()) {
220 return false;
221 }
222 return a1.value() == other;
223}
224
225inline bool operator==(
226 const c10::IntArrayRef& a1,
227 const c10::OptionalIntArrayRef& a2) {
228 return a2 == a1;
229}
230
231} // namespace c10
232