1#pragma once
2
3#include <ATen/core/ivalue_to.h>
4#include <ATen/core/jit_type_base.h>
5#include <c10/macros/Macros.h>
6#include <c10/macros/Export.h>
7#include <c10/util/TypeTraits.h>
8#include <c10/util/TypeList.h>
9#include <c10/util/intrusive_ptr.h>
10#include <c10/util/ArrayRef.h>
11#include <c10/util/Optional.h>
12#include <vector>
13
// Forward declaration only; this header never needs the complete Tensor type.
namespace at {
class Tensor;
} // namespace at
17namespace c10 {
// Forward declarations used by the declarations that follow.
struct Type;
struct IValue;
template <class T>
class List;
21
22namespace detail {
23
24struct ListImpl final : public c10::intrusive_ptr_target {
25 using list_type = std::vector<IValue>;
26
27 explicit ListImpl(list_type list_, TypePtr elementType_)
28 : list(std::move(list_))
29 , elementType(std::move(elementType_)) {}
30
31 list_type list;
32
33 TypePtr elementType;
34
35 intrusive_ptr<ListImpl> copy() const {
36 return make_intrusive<ListImpl>(list, elementType);
37 }
38 friend TORCH_API bool operator==(const ListImpl& lhs, const ListImpl& rhs);
39};
40}
41
42namespace impl {
43
// Forward declarations: the proxy reference type, the iterator type and the
// free functions operating on the proxy all refer to one another, so they
// are announced up front.
template <class T, class Iterator>
class ListElementReference;

template <class T, class Iterator>
class ListIterator;

// Swap overload for two element proxies (both taken as rvalues).
template <class T, class Iterator>
void swap(
    ListElementReference<T, Iterator>&& lhs,
    ListElementReference<T, Iterator>&& rhs);

// Comparison of an element proxy against a plain value, in either order.
template <class T, class Iterator>
bool operator==(const ListElementReference<T, Iterator>& lhs, const T& rhs);

template <class T, class Iterator>
bool operator==(const T& lhs, const ListElementReference<T, Iterator>& rhs);
56
57template<class T>
58struct ListElementConstReferenceTraits {
59 // In the general case, we use IValue::to().
60 using const_reference = typename c10::detail::ivalue_to_const_ref_overload_return<T>::type;
61};
62
63// There is no to() overload for c10::optional<std::string>.
64template<>
65struct ListElementConstReferenceTraits<c10::optional<std::string>> {
66 using const_reference = c10::optional<std::reference_wrapper<const std::string>>;
67};
68
69template<class T, class Iterator>
70class ListElementReference final {
71public:
72 operator std::conditional_t<
73 std::is_reference<typename c10::detail::
74 ivalue_to_const_ref_overload_return<T>::type>::value,
75 const T&,
76 T>() const;
77
78 ListElementReference& operator=(T&& new_value) &&;
79
80 ListElementReference& operator=(const T& new_value) &&;
81
82 // assigning another ref to this assigns the underlying value
83 ListElementReference& operator=(ListElementReference&& rhs) &&;
84
85 const IValue& get() const& {
86 return *iterator_;
87 }
88
89 friend void swap<T, Iterator>(ListElementReference&& lhs, ListElementReference&& rhs);
90
91 ListElementReference(const ListElementReference&) = delete;
92 ListElementReference& operator=(const ListElementReference&) = delete;
93
94private:
95 ListElementReference(Iterator iter)
96 : iterator_(iter) {}
97
98 // allow moving, but only our friends (i.e. the List class) can move us
99 ListElementReference(ListElementReference&&) noexcept = default;
100 ListElementReference& operator=(ListElementReference&& rhs) & noexcept {
101 iterator_ = std::move(rhs.iterator_);
102 return *this;
103 }
104
105 friend class List<T>;
106 friend class ListIterator<T, Iterator>;
107
108 Iterator iterator_;
109};
110
111// this wraps vector::iterator to make sure user code can't rely
112// on it being the type of the underlying vector.
113template <class T, class Iterator>
114class ListIterator final {
115 public:
116 // C++17 friendly std::iterator implementation
117 using iterator_category = std::random_access_iterator_tag;
118 using value_type = T;
119 using difference_type = std::ptrdiff_t;
120 using pointer = T*;
121 using reference = ListElementReference<T, Iterator>;
122
123 explicit ListIterator() = default;
124 ~ListIterator() = default;
125
126 ListIterator(const ListIterator&) = default;
127 ListIterator(ListIterator&&) noexcept = default;
128 ListIterator& operator=(const ListIterator&) = default;
129 ListIterator& operator=(ListIterator&&) = default;
130
131 ListIterator& operator++() {
132 ++iterator_;
133 return *this;
134 }
135
136 ListIterator operator++(int) {
137 ListIterator copy(*this);
138 ++*this;
139 return copy;
140 }
141
142 ListIterator& operator--() {
143 --iterator_;
144 return *this;
145 }
146
147 ListIterator operator--(int) {
148 ListIterator copy(*this);
149 --*this;
150 return copy;
151 }
152
153 ListIterator& operator+=(typename List<T>::size_type offset) {
154 iterator_ += offset;
155 return *this;
156 }
157
158 ListIterator& operator-=(typename List<T>::size_type offset) {
159 iterator_ -= offset;
160 return *this;
161 }
162
163 ListIterator operator+(typename List<T>::size_type offset) const {
164 return ListIterator{iterator_ + offset};
165 }
166
167 ListIterator operator-(typename List<T>::size_type offset) const {
168 return ListIterator{iterator_ - offset};
169 }
170
171 friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) {
172 return lhs.iterator_ - rhs.iterator_;
173 }
174
175 ListElementReference<T, Iterator> operator*() const {
176 return {iterator_};
177 }
178
179 ListElementReference<T, Iterator> operator[](typename List<T>::size_type offset) const {
180 return {iterator_ + offset};
181 }
182
183private:
184 explicit ListIterator(Iterator iterator): iterator_(std::move(iterator)) {}
185
186 Iterator iterator_;
187
188 friend bool operator==(const ListIterator& lhs, const ListIterator& rhs) {
189 return lhs.iterator_ == rhs.iterator_;
190 }
191
192 friend bool operator!=(const ListIterator& lhs, const ListIterator& rhs) {
193 return !(lhs == rhs);
194 }
195
196 friend bool operator<(const ListIterator& lhs, const ListIterator& rhs) {
197 return lhs.iterator_ < rhs.iterator_;
198 }
199
200 friend bool operator<=(const ListIterator& lhs, const ListIterator& rhs) {
201 return lhs.iterator_ <= rhs.iterator_;
202 }
203
204 friend bool operator>(const ListIterator& lhs, const ListIterator& rhs) {
205 return lhs.iterator_ > rhs.iterator_;
206 }
207
208 friend bool operator>=(const ListIterator& lhs, const ListIterator& rhs) {
209 return lhs.iterator_ >= rhs.iterator_;
210 }
211
212 friend class ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
213 friend class List<T>;
214};
215
216template<class T> List<T> toTypedList(List<IValue> list);
217template<class T> List<IValue> toList(List<T>&& list);
218template<class T> List<IValue> toList(const List<T>& list);
219const IValue* ptr_to_first_element(const List<IValue>& list);
220}
221
/**
 * An object of this class stores a list of values of type T.
 *
 * This is a pointer type. After a copy, both Lists
 * will share the same storage:
 *
 * > List<std::string> a;
 * > List<std::string> b = a;
 * > b.push_back("three");
 * > ASSERT("three" == a.get(0));
 *
 * We use this class in the PyTorch kernel API instead of
 * std::vector<T>, because that allows us to do optimizations
 * and switch out the underlying list implementation without
 * breaking backwards compatibility for the kernel API.
 */
238template<class T>
239class List final {
240private:
241 // This is an intrusive_ptr because List is a pointer type.
242 // Invariant: This will never be a nullptr, there will always be a valid
243 // ListImpl.
244 c10::intrusive_ptr<c10::detail::ListImpl> impl_;
245
246 using internal_reference_type = impl::ListElementReference<T, typename c10::detail::ListImpl::list_type::iterator>;
247 using internal_const_reference_type = typename impl::ListElementConstReferenceTraits<T>::const_reference;
248
249public:
250 using value_type = T;
251 using size_type = typename c10::detail::ListImpl::list_type::size_type;
252 using iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
253 using const_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::iterator>;
254 using reverse_iterator = impl::ListIterator<T, typename c10::detail::ListImpl::list_type::reverse_iterator>;
255
256 /**
257 * Constructs an empty list.
258 */
259 explicit List();
260
261 /**
262 * Constructs a list with some initial values.
263 * Example:
264 * List<int> a({2, 3, 4});
265 */
266 List(std::initializer_list<T> initial_values);
267 explicit List(ArrayRef<T> initial_values);
268
269 /**
270 * Create a generic list with runtime type information.
271 * This only works for c10::impl::GenericList and is not part of the public API
272 * but only supposed to be used internally by PyTorch.
273 */
274 explicit List(TypePtr elementType);
275
276 List(const List&) = default;
277 List& operator=(const List&) = default;
278
279 /**
280 * Create a new List pointing to a deep copy of the same data.
281 * The List returned is a new list with separate storage.
282 * Changes in it are not reflected in the original list or vice versa.
283 */
284 List copy() const;
285
286 /**
287 * Returns the element at specified location pos, with bounds checking.
288 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
289 */
290 value_type get(size_type pos) const;
291
292 /**
293 * Moves out the element at the specified location pos and returns it, with bounds checking.
294 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
295 * The list contains an invalid element at position pos afterwards. Any operations
296 * on it before re-setting it are invalid.
297 */
298 value_type extract(size_type pos) const;
299
300 /**
301 * Returns a reference to the element at specified location pos, with bounds checking.
302 * If pos is not within the range of the container, an exception of type std::out_of_range is thrown.
303 *
304 * You cannot store the reference, but you can read it and assign new values to it:
305 *
306 * List<int64_t> list = ...;
307 * list[2] = 5;
308 * int64_t v = list[1];
309 */
310 internal_const_reference_type operator[](size_type pos) const;
311
312 internal_reference_type operator[](size_type pos);
313
314 /**
315 * Assigns a new value to the element at location pos.
316 */
317 void set(size_type pos, const value_type& value) const;
318
319 /**
320 * Assigns a new value to the element at location pos.
321 */
322 void set(size_type pos, value_type&& value) const;
323
324 /**
325 * Returns an iterator to the first element of the container.
326 * If the container is empty, the returned iterator will be equal to end().
327 */
328 iterator begin() const;
329
330 /**
331 * Returns an iterator to the element following the last element of the container.
332 * This element acts as a placeholder; attempting to access it results in undefined behavior.
333 */
334 iterator end() const;
335
336 /**
337 * Checks if the container has no elements.
338 */
339 bool empty() const;
340
341 /**
342 * Returns the number of elements in the container
343 */
344 size_type size() const;
345
346 /**
347 * Increase the capacity of the vector to a value that's greater or equal to new_cap.
348 */
349 void reserve(size_type new_cap) const;
350
351 /**
352 * Erases all elements from the container. After this call, size() returns zero.
353 * Invalidates any references, pointers, or iterators referring to contained elements. Any past-the-end iterators are also invalidated.
354 */
355 void clear() const;
356
357 /**
358 * Inserts value before pos.
359 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
360 */
361 iterator insert(iterator pos, const T& value) const;
362
363 /**
364 * Inserts value before pos.
365 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
366 */
367 iterator insert(iterator pos, T&& value) const;
368
369 /**
370 * Inserts a new element into the container directly before pos.
371 * The new element is constructed with the given arguments.
372 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
373 */
374 template<class... Args>
375 iterator emplace(iterator pos, Args&&... value) const;
376
377 /**
378 * Appends the given element value to the end of the container.
379 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
380 */
381 void push_back(const T& value) const;
382
383 /**
384 * Appends the given element value to the end of the container.
385 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
386 */
387 void push_back(T&& value) const;
388
389 /**
390 * Appends the given list to the end of the container. Uses at most one memory allocation.
391 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
392 */
393 void append(List<T> lst) const;
394
395 /**
396 * Appends the given element value to the end of the container.
397 * The new element is constructed with the given arguments.
398 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
399 */
400 template<class... Args>
401 void emplace_back(Args&&... args) const;
402
403 /**
404 * Removes the element at pos.
405 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
406 */
407 iterator erase(iterator pos) const;
408
409 /**
410 * Removes the elements in the range [first, last).
411 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
412 */
413 iterator erase(iterator first, iterator last) const;
414
415 /**
416 * Removes the last element of the container.
417 * Calling pop_back on an empty container is undefined.
418 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
419 */
420 void pop_back() const;
421
422 /**
423 * Resizes the container to contain count elements.
424 * If the current size is less than count, additional default-inserted elements are appended.
425 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
426 */
427 void resize(size_type count) const;
428
429 /**
430 * Resizes the container to contain count elements.
431 * If the current size is less than count, additional copies of value are appended.
432 * May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
433 */
434 void resize(size_type count, const T& value) const;
435
436 /**
437 * Value equality comparison. This function implements Python-like semantics for
438 * equality: two lists with the same identity (e.g. same pointer) trivially
439 * compare equal, otherwise each element is compared for equality.
440 */
441 template <class T_>
442 friend bool operator==(const List<T_>& lhs, const List<T_>& rhs);
443
444 template <class T_>
445 friend bool operator!=(const List<T_>& lhs, const List<T_>& rhs);
446
447 /**
448 * Identity comparison. Returns true if and only if `rhs` represents the same
449 * List object as `this`.
450 */
451 bool is(const List<T>& rhs) const;
452
453 std::vector<T> vec() const;
454
455 /**
456 * Returns the number of Lists currently pointing to this same list.
457 * If this is the only instance pointing to this list, returns 1.
458 */
459 // TODO Test use_count
460 size_t use_count() const;
461
462 TypePtr elementType() const;
463
464 // See [unsafe set type] for why this exists.
465 void unsafeSetElementType(TypePtr t);
466
467private:
468 explicit List(c10::intrusive_ptr<c10::detail::ListImpl>&& elements);
469 explicit List(const c10::intrusive_ptr<c10::detail::ListImpl>& elements);
470 friend struct IValue;
471 template<class T_> friend List<T_> impl::toTypedList(List<IValue>);
472 template<class T_> friend List<IValue> impl::toList(List<T_>&&);
473 template<class T_> friend List<IValue> impl::toList(const List<T_>&);
474 friend const IValue* impl::ptr_to_first_element(const List<IValue>& list);
475};
476
477namespace impl {
478// GenericList is how IValue stores lists. It is, however, not part of the
479// public API. Kernels should use Lists with concrete types instead
480// (maybe except for some internal prim ops).
481using GenericList = List<IValue>;
482
483inline const IValue* ptr_to_first_element(const GenericList& list) {
484 return &list.impl_->list[0];
485}
486
487}
488}
489
490namespace torch {
491 template<class T> using List = c10::List<T>;
492}
493
494#include <ATen/core/List_inl.h> // IWYU pragma: keep
495