1//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
// ATen: modified from llvm::ArrayRef.
// - Removed LLVM-specific functionality.
// - Removed some implicit const -> non-const conversions that relied on
//   complicated std::enable_if metaprogramming.
// - Removed a number of slice variants for simplicity.
15
16#pragma once
17
#include <c10/util/C++17.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <type_traits>
#include <vector>
26
27namespace c10 {
28/// ArrayRef - Represent a constant reference to an array (0 or more elements
29/// consecutively in memory), i.e. a start pointer and a length. It allows
30/// various APIs to take consecutive elements easily and conveniently.
31///
32/// This class does not own the underlying data, it is expected to be used in
33/// situations where the data resides in some other buffer, whose lifetime
34/// extends past that of the ArrayRef. For this reason, it is not in general
35/// safe to store an ArrayRef.
36///
37/// This is intended to be trivially copyable, so it should be passed by
38/// value.
39template <typename T>
40class ArrayRef final {
41 public:
42 using iterator = const T*;
43 using const_iterator = const T*;
44 using size_type = size_t;
45 using value_type = T;
46
47 using reverse_iterator = std::reverse_iterator<iterator>;
48
49 private:
50 /// The start of the array, in an external buffer.
51 const T* Data;
52
53 /// The number of elements.
54 size_type Length;
55
56 void debugCheckNullptrInvariant() {
57 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
58 Data != nullptr || Length == 0,
59 "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal");
60 }
61
62 public:
63 /// @name Constructors
64 /// @{
65
66 /// Construct an empty ArrayRef.
67 /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
68
69 /// Construct an ArrayRef from a single element.
70 // TODO Make this explicit
71 constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
72
73 /// Construct an ArrayRef from a pointer and length.
74 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length)
75 : Data(data), Length(length) {
76 debugCheckNullptrInvariant();
77 }
78
79 /// Construct an ArrayRef from a range.
80 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end)
81 : Data(begin), Length(end - begin) {
82 debugCheckNullptrInvariant();
83 }
84
85 /// Construct an ArrayRef from a SmallVector. This is templated in order to
86 /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
87 /// copy-construct an ArrayRef.
88 template <typename U>
89 /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
90 : Data(Vec.data()), Length(Vec.size()) {
91 debugCheckNullptrInvariant();
92 }
93
94 template <
95 typename Container,
96 typename = std::enable_if_t<std::is_same<
97 std::remove_const_t<decltype(std::declval<Container>().data())>,
98 T*>::value>>
99 /* implicit */ ArrayRef(const Container& container)
100 : Data(container.data()), Length(container.size()) {
101 debugCheckNullptrInvariant();
102 }
103
104 /// Construct an ArrayRef from a std::vector.
105 // The enable_if stuff here makes sure that this isn't used for
106 // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
107 // bitfield.
108 template <typename A>
109 /* implicit */ ArrayRef(const std::vector<T, A>& Vec)
110 : Data(Vec.data()), Length(Vec.size()) {
111 static_assert(
112 !std::is_same<T, bool>::value,
113 "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
114 }
115
116 /// Construct an ArrayRef from a std::array
117 template <size_t N>
118 /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
119 : Data(Arr.data()), Length(N) {}
120
121 /// Construct an ArrayRef from a C array.
122 template <size_t N>
123 /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
124
125 /// Construct an ArrayRef from a std::initializer_list.
126 /* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
127 : Data(
128 std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
129 : std::begin(Vec)),
130 Length(Vec.size()) {}
131
132 /// @}
133 /// @name Simple Operations
134 /// @{
135
136 constexpr iterator begin() const {
137 return Data;
138 }
139 constexpr iterator end() const {
140 return Data + Length;
141 }
142
143 // These are actually the same as iterator, since ArrayRef only
144 // gives you const iterators.
145 constexpr const_iterator cbegin() const {
146 return Data;
147 }
148 constexpr const_iterator cend() const {
149 return Data + Length;
150 }
151
152 constexpr reverse_iterator rbegin() const {
153 return reverse_iterator(end());
154 }
155 constexpr reverse_iterator rend() const {
156 return reverse_iterator(begin());
157 }
158
159 /// empty - Check if the array is empty.
160 constexpr bool empty() const {
161 return Length == 0;
162 }
163
164 constexpr const T* data() const {
165 return Data;
166 }
167
168 /// size - Get the array size.
169 constexpr size_t size() const {
170 return Length;
171 }
172
173 /// front - Get the first element.
174 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const {
175 TORCH_CHECK(
176 !empty(), "ArrayRef: attempted to access front() of empty list");
177 return Data[0];
178 }
179
180 /// back - Get the last element.
181 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const {
182 TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
183 return Data[Length - 1];
184 }
185
186 /// equals - Check for element-wise equality.
187 constexpr bool equals(ArrayRef RHS) const {
188 return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
189 }
190
191 /// slice(n, m) - Take M elements of the array starting at element N
192 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef<T> slice(size_t N, size_t M)
193 const {
194 TORCH_CHECK(
195 N + M <= size(),
196 "ArrayRef: invalid slice, N = ",
197 N,
198 "; M = ",
199 M,
200 "; size = ",
201 size());
202 return ArrayRef<T>(data() + N, M);
203 }
204
205 /// slice(n) - Chop off the first N elements of the array.
206 constexpr ArrayRef<T> slice(size_t N) const {
207 return slice(N, size() - N);
208 }
209
210 /// @}
211 /// @name Operator Overloads
212 /// @{
213 constexpr const T& operator[](size_t Index) const {
214 return Data[Index];
215 }
216
217 /// Vector compatibility
218 C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const {
219 TORCH_CHECK(
220 Index < Length,
221 "ArrayRef: invalid index Index = ",
222 Index,
223 "; Length = ",
224 Length);
225 return Data[Index];
226 }
227
228 /// Disallow accidental assignment from a temporary.
229 ///
230 /// The declaration here is extra complicated so that "arrayRef = {}"
231 /// continues to select the move assignment operator.
232 template <typename U>
233 typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
234 operator=(U&& Temporary) = delete;
235
236 /// Disallow accidental assignment from a temporary.
237 ///
238 /// The declaration here is extra complicated so that "arrayRef = {}"
239 /// continues to select the move assignment operator.
240 template <typename U>
241 typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
242 operator=(std::initializer_list<U>) = delete;
243
244 /// @}
245 /// @name Expensive Operations
246 /// @{
247 std::vector<T> vec() const {
248 return std::vector<T>(Data, Data + Length);
249 }
250
251 /// @}
252};
253
254template <typename T>
255std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
256 int i = 0;
257 out << "[";
258 for (const auto& e : list) {
259 if (i++ > 0)
260 out << ", ";
261 out << e;
262 }
263 out << "]";
264 return out;
265}
266
267/// @name ArrayRef Convenience constructors
268/// @{
269
270/// Construct an ArrayRef from a single element.
271template <typename T>
272ArrayRef<T> makeArrayRef(const T& OneElt) {
273 return OneElt;
274}
275
276/// Construct an ArrayRef from a pointer and length.
277template <typename T>
278ArrayRef<T> makeArrayRef(const T* data, size_t length) {
279 return ArrayRef<T>(data, length);
280}
281
282/// Construct an ArrayRef from a range.
283template <typename T>
284ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
285 return ArrayRef<T>(begin, end);
286}
287
288/// Construct an ArrayRef from a SmallVector.
289template <typename T>
290ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
291 return Vec;
292}
293
294/// Construct an ArrayRef from a SmallVector.
295template <typename T, unsigned N>
296ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
297 return Vec;
298}
299
300/// Construct an ArrayRef from a std::vector.
301template <typename T>
302ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
303 return Vec;
304}
305
306/// Construct an ArrayRef from a std::array.
307template <typename T, std::size_t N>
308ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
309 return Arr;
310}
311
312/// Construct an ArrayRef from an ArrayRef (no-op) (const)
313template <typename T>
314ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
315 return Vec;
316}
317
318/// Construct an ArrayRef from an ArrayRef (no-op)
319template <typename T>
320ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
321 return Vec;
322}
323
324/// Construct an ArrayRef from a C array.
325template <typename T, size_t N>
326ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
327 return ArrayRef<T>(Arr);
328}
329
330// WARNING: Template instantiation will NOT be willing to do an implicit
331// conversions to get you to an c10::ArrayRef, which is why we need so
332// many overloads.
333
334template <typename T>
335bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
336 return a1.equals(a2);
337}
338
339template <typename T>
340bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
341 return !a1.equals(a2);
342}
343
344template <typename T>
345bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
346 return c10::ArrayRef<T>(a1).equals(a2);
347}
348
349template <typename T>
350bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
351 return !c10::ArrayRef<T>(a1).equals(a2);
352}
353
354template <typename T>
355bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
356 return a1.equals(c10::ArrayRef<T>(a2));
357}
358
359template <typename T>
360bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
361 return !a1.equals(c10::ArrayRef<T>(a2));
362}
363
364using IntArrayRef = ArrayRef<int64_t>;
365
366// This alias is deprecated because it doesn't make ownership
367// semantics obvious. Use IntArrayRef instead!
368C10_DEFINE_DEPRECATED_USING(IntList, ArrayRef<int64_t>)
369
370} // namespace c10
371