1 | //===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// |
2 | // |
3 | // The LLVM Compiler Infrastructure |
4 | // |
5 | // This file is distributed under the University of Illinois Open Source |
6 | // License. See LICENSE.TXT for details. |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | // ATen: modified from llvm::ArrayRef. |
11 | // removed llvm-specific functionality |
12 | // removed some implicit const -> non-const conversions that rely on |
13 | // complicated std::enable_if meta-programming |
14 | // removed a bunch of slice variants for simplicity... |
15 | |
16 | #pragma once |
17 | |
#include <c10/util/C++17.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <type_traits>
#include <utility>
#include <vector>
26 | |
27 | namespace c10 { |
28 | /// ArrayRef - Represent a constant reference to an array (0 or more elements |
29 | /// consecutively in memory), i.e. a start pointer and a length. It allows |
30 | /// various APIs to take consecutive elements easily and conveniently. |
31 | /// |
32 | /// This class does not own the underlying data, it is expected to be used in |
33 | /// situations where the data resides in some other buffer, whose lifetime |
34 | /// extends past that of the ArrayRef. For this reason, it is not in general |
35 | /// safe to store an ArrayRef. |
36 | /// |
37 | /// This is intended to be trivially copyable, so it should be passed by |
38 | /// value. |
39 | template <typename T> |
40 | class ArrayRef final { |
41 | public: |
42 | using iterator = const T*; |
43 | using const_iterator = const T*; |
44 | using size_type = size_t; |
45 | using value_type = T; |
46 | |
47 | using reverse_iterator = std::reverse_iterator<iterator>; |
48 | |
49 | private: |
50 | /// The start of the array, in an external buffer. |
51 | const T* Data; |
52 | |
53 | /// The number of elements. |
54 | size_type Length; |
55 | |
56 | void debugCheckNullptrInvariant() { |
57 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
58 | Data != nullptr || Length == 0, |
59 | "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal" ); |
60 | } |
61 | |
62 | public: |
63 | /// @name Constructors |
64 | /// @{ |
65 | |
66 | /// Construct an empty ArrayRef. |
67 | /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} |
68 | |
69 | /// Construct an ArrayRef from a single element. |
70 | // TODO Make this explicit |
71 | constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} |
72 | |
73 | /// Construct an ArrayRef from a pointer and length. |
74 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length) |
75 | : Data(data), Length(length) { |
76 | debugCheckNullptrInvariant(); |
77 | } |
78 | |
79 | /// Construct an ArrayRef from a range. |
80 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end) |
81 | : Data(begin), Length(end - begin) { |
82 | debugCheckNullptrInvariant(); |
83 | } |
84 | |
85 | /// Construct an ArrayRef from a SmallVector. This is templated in order to |
86 | /// avoid instantiating SmallVectorTemplateCommon<T> whenever we |
87 | /// copy-construct an ArrayRef. |
88 | template <typename U> |
89 | /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec) |
90 | : Data(Vec.data()), Length(Vec.size()) { |
91 | debugCheckNullptrInvariant(); |
92 | } |
93 | |
94 | template < |
95 | typename Container, |
96 | typename = std::enable_if_t<std::is_same< |
97 | std::remove_const_t<decltype(std::declval<Container>().data())>, |
98 | T*>::value>> |
99 | /* implicit */ ArrayRef(const Container& container) |
100 | : Data(container.data()), Length(container.size()) { |
101 | debugCheckNullptrInvariant(); |
102 | } |
103 | |
104 | /// Construct an ArrayRef from a std::vector. |
105 | // The enable_if stuff here makes sure that this isn't used for |
106 | // std::vector<bool>, because ArrayRef can't work on a std::vector<bool> |
107 | // bitfield. |
108 | template <typename A> |
109 | /* implicit */ ArrayRef(const std::vector<T, A>& Vec) |
110 | : Data(Vec.data()), Length(Vec.size()) { |
111 | static_assert( |
112 | !std::is_same<T, bool>::value, |
113 | "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield." ); |
114 | } |
115 | |
116 | /// Construct an ArrayRef from a std::array |
117 | template <size_t N> |
118 | /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr) |
119 | : Data(Arr.data()), Length(N) {} |
120 | |
121 | /// Construct an ArrayRef from a C array. |
122 | template <size_t N> |
123 | /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} |
124 | |
125 | /// Construct an ArrayRef from a std::initializer_list. |
126 | /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec) |
127 | : Data( |
128 | std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr) |
129 | : std::begin(Vec)), |
130 | Length(Vec.size()) {} |
131 | |
132 | /// @} |
133 | /// @name Simple Operations |
134 | /// @{ |
135 | |
136 | constexpr iterator begin() const { |
137 | return Data; |
138 | } |
139 | constexpr iterator end() const { |
140 | return Data + Length; |
141 | } |
142 | |
143 | // These are actually the same as iterator, since ArrayRef only |
144 | // gives you const iterators. |
145 | constexpr const_iterator cbegin() const { |
146 | return Data; |
147 | } |
148 | constexpr const_iterator cend() const { |
149 | return Data + Length; |
150 | } |
151 | |
152 | constexpr reverse_iterator rbegin() const { |
153 | return reverse_iterator(end()); |
154 | } |
155 | constexpr reverse_iterator rend() const { |
156 | return reverse_iterator(begin()); |
157 | } |
158 | |
159 | /// empty - Check if the array is empty. |
160 | constexpr bool empty() const { |
161 | return Length == 0; |
162 | } |
163 | |
164 | constexpr const T* data() const { |
165 | return Data; |
166 | } |
167 | |
168 | /// size - Get the array size. |
169 | constexpr size_t size() const { |
170 | return Length; |
171 | } |
172 | |
173 | /// front - Get the first element. |
174 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const { |
175 | TORCH_CHECK( |
176 | !empty(), "ArrayRef: attempted to access front() of empty list" ); |
177 | return Data[0]; |
178 | } |
179 | |
180 | /// back - Get the last element. |
181 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const { |
182 | TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list" ); |
183 | return Data[Length - 1]; |
184 | } |
185 | |
186 | /// equals - Check for element-wise equality. |
187 | constexpr bool equals(ArrayRef RHS) const { |
188 | return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); |
189 | } |
190 | |
191 | /// slice(n, m) - Take M elements of the array starting at element N |
192 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M) |
193 | const { |
194 | TORCH_CHECK( |
195 | N + M <= size(), |
196 | "ArrayRef: invalid slice, N = " , |
197 | N, |
198 | "; M = " , |
199 | M, |
200 | "; size = " , |
201 | size()); |
202 | return ArrayRef<T>(data() + N, M); |
203 | } |
204 | |
205 | /// slice(n) - Chop off the first N elements of the array. |
206 | constexpr ArrayRef<T> slice(size_t N) const { |
207 | return slice(N, size() - N); |
208 | } |
209 | |
210 | /// @} |
211 | /// @name Operator Overloads |
212 | /// @{ |
213 | constexpr const T& operator[](size_t Index) const { |
214 | return Data[Index]; |
215 | } |
216 | |
217 | /// Vector compatibility |
218 | C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const { |
219 | TORCH_CHECK( |
220 | Index < Length, |
221 | "ArrayRef: invalid index Index = " , |
222 | Index, |
223 | "; Length = " , |
224 | Length); |
225 | return Data[Index]; |
226 | } |
227 | |
228 | /// Disallow accidental assignment from a temporary. |
229 | /// |
230 | /// The declaration here is extra complicated so that "arrayRef = {}" |
231 | /// continues to select the move assignment operator. |
232 | template <typename U> |
233 | typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& |
234 | operator=(U&& Temporary) = delete; |
235 | |
236 | /// Disallow accidental assignment from a temporary. |
237 | /// |
238 | /// The declaration here is extra complicated so that "arrayRef = {}" |
239 | /// continues to select the move assignment operator. |
240 | template <typename U> |
241 | typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& |
242 | operator=(std::initializer_list<U>) = delete; |
243 | |
244 | /// @} |
245 | /// @name Expensive Operations |
246 | /// @{ |
247 | std::vector<T> vec() const { |
248 | return std::vector<T>(Data, Data + Length); |
249 | } |
250 | |
251 | /// @} |
252 | }; |
253 | |
254 | template <typename T> |
255 | std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) { |
256 | int i = 0; |
257 | out << "[" ; |
258 | for (const auto& e : list) { |
259 | if (i++ > 0) |
260 | out << ", " ; |
261 | out << e; |
262 | } |
263 | out << "]" ; |
264 | return out; |
265 | } |
266 | |
267 | /// @name ArrayRef Convenience constructors |
268 | /// @{ |
269 | |
270 | /// Construct an ArrayRef from a single element. |
271 | template <typename T> |
272 | ArrayRef<T> makeArrayRef(const T& OneElt) { |
273 | return OneElt; |
274 | } |
275 | |
276 | /// Construct an ArrayRef from a pointer and length. |
277 | template <typename T> |
278 | ArrayRef<T> makeArrayRef(const T* data, size_t length) { |
279 | return ArrayRef<T>(data, length); |
280 | } |
281 | |
282 | /// Construct an ArrayRef from a range. |
283 | template <typename T> |
284 | ArrayRef<T> makeArrayRef(const T* begin, const T* end) { |
285 | return ArrayRef<T>(begin, end); |
286 | } |
287 | |
288 | /// Construct an ArrayRef from a SmallVector. |
289 | template <typename T> |
290 | ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) { |
291 | return Vec; |
292 | } |
293 | |
294 | /// Construct an ArrayRef from a SmallVector. |
295 | template <typename T, unsigned N> |
296 | ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) { |
297 | return Vec; |
298 | } |
299 | |
300 | /// Construct an ArrayRef from a std::vector. |
301 | template <typename T> |
302 | ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) { |
303 | return Vec; |
304 | } |
305 | |
306 | /// Construct an ArrayRef from a std::array. |
307 | template <typename T, std::size_t N> |
308 | ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) { |
309 | return Arr; |
310 | } |
311 | |
312 | /// Construct an ArrayRef from an ArrayRef (no-op) (const) |
313 | template <typename T> |
314 | ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) { |
315 | return Vec; |
316 | } |
317 | |
318 | /// Construct an ArrayRef from an ArrayRef (no-op) |
319 | template <typename T> |
320 | ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) { |
321 | return Vec; |
322 | } |
323 | |
324 | /// Construct an ArrayRef from a C array. |
325 | template <typename T, size_t N> |
326 | ArrayRef<T> makeArrayRef(const T (&Arr)[N]) { |
327 | return ArrayRef<T>(Arr); |
328 | } |
329 | |
330 | // WARNING: Template instantiation will NOT be willing to do an implicit |
331 | // conversions to get you to an c10::ArrayRef, which is why we need so |
332 | // many overloads. |
333 | |
334 | template <typename T> |
335 | bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) { |
336 | return a1.equals(a2); |
337 | } |
338 | |
339 | template <typename T> |
340 | bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) { |
341 | return !a1.equals(a2); |
342 | } |
343 | |
344 | template <typename T> |
345 | bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) { |
346 | return c10::ArrayRef<T>(a1).equals(a2); |
347 | } |
348 | |
349 | template <typename T> |
350 | bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) { |
351 | return !c10::ArrayRef<T>(a1).equals(a2); |
352 | } |
353 | |
354 | template <typename T> |
355 | bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) { |
356 | return a1.equals(c10::ArrayRef<T>(a2)); |
357 | } |
358 | |
359 | template <typename T> |
360 | bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) { |
361 | return !a1.equals(c10::ArrayRef<T>(a2)); |
362 | } |
363 | |
/// Non-owning view over a sequence of int64_t values.
using IntArrayRef = ArrayRef<int64_t>;

// This alias is deprecated because it doesn't make ownership
// semantics obvious. Use IntArrayRef instead!
// NOTE(review): C10_DEFINE_DEPRECATED_USING presumably (from its name and the
// c10/util/Deprecated.h include) declares IntList as a deprecated alias of
// ArrayRef<int64_t>; confirm against that header.
C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
369 | |
370 | } // namespace c10 |
371 | |