1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/array.h
22 * \brief Runtime Array container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_
25#define TVM_RUNTIME_CONTAINER_ARRAY_H_
26
27#include <algorithm>
28#include <memory>
29#include <type_traits>
30#include <utility>
31#include <vector>
32
33#include "./base.h"
34#include "./optional.h"
35
36namespace tvm {
37namespace runtime {
38
39/*! \brief array node content in array */
40class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
41 public:
42 /*! \return The size of the array */
43 size_t size() const { return this->size_; }
44
45 /*!
46 * \brief Read i-th element from array.
47 * \param i The index
48 * \return the i-th element.
49 */
50 const ObjectRef at(int64_t i) const { return this->operator[](i); }
51
52 /*! \return begin constant iterator */
53 const ObjectRef* begin() const { return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); }
54
55 /*! \return end constant iterator */
56 const ObjectRef* end() const { return begin() + size_; }
57
58 /*! \brief Release reference to all the elements */
59 void clear() { ShrinkBy(size_); }
60
61 /*!
62 * \brief Set i-th element of the array in-place
63 * \param i The index
64 * \param item The value to be set
65 */
66 void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); }
67
68 /*!
69 * \brief Constructs a container and copy from another
70 * \param cap The capacity of the container
71 * \param from Source of the copy
72 * \return Ref-counted ArrayNode requested
73 */
74 static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) {
75 int64_t size = from->size_;
76 ICHECK_GE(cap, size) << "ValueError: not enough capacity";
77 ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap);
78 ObjectRef* write = p->MutableBegin();
79 ObjectRef* read = from->MutableBegin();
80 // To ensure exception safety, size is only incremented after the initialization succeeds
81 for (int64_t& i = p->size_ = 0; i < size; ++i) {
82 new (write++) ObjectRef(*read++);
83 }
84 return p;
85 }
86
87 /*!
88 * \brief Constructs a container and move from another
89 * \param cap The capacity of the container
90 * \param from Source of the move
91 * \return Ref-counted ArrayNode requested
92 */
93 static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) {
94 int64_t size = from->size_;
95 ICHECK_GE(cap, size) << "ValueError: not enough capacity";
96 ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap);
97 ObjectRef* write = p->MutableBegin();
98 ObjectRef* read = from->MutableBegin();
99 // To ensure exception safety, size is only incremented after the initialization succeeds
100 for (int64_t& i = p->size_ = 0; i < size; ++i) {
101 new (write++) ObjectRef(std::move(*read++));
102 }
103 from->size_ = 0;
104 return p;
105 }
106
107 /*!
108 * \brief Constructs a container with n elements. Each element is a copy of val
109 * \param n The size of the container
110 * \param val The init value
111 * \return Ref-counted ArrayNode requested
112 */
113 static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const ObjectRef& val) {
114 ObjectPtr<ArrayNode> p = ArrayNode::Empty(n);
115 ObjectRef* itr = p->MutableBegin();
116 for (int64_t& i = p->size_ = 0; i < n; ++i) {
117 new (itr++) ObjectRef(val);
118 }
119 return p;
120 }
121
122 static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray;
123 static constexpr const char* _type_key = "Array";
124 TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
125
126 private:
127 /*! \return Size of initialized memory, used by InplaceArrayBase. */
128 size_t GetSize() const { return this->size_; }
129
130 /*! \return begin mutable iterator */
131 ObjectRef* MutableBegin() const {
132 return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0));
133 }
134
135 /*! \return end mutable iterator */
136 ObjectRef* MutableEnd() const { return MutableBegin() + size_; }
137
138 /*!
139 * \brief Create an ArrayNode with the given capacity.
140 * \param n Required capacity
141 * \return Ref-counted ArrayNode requested
142 */
143 static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) {
144 ICHECK_GE(n, 0);
145 ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, ObjectRef>(n);
146 p->capacity_ = n;
147 p->size_ = 0;
148 return p;
149 }
150
151 /*!
152 * \brief Inplace-initialize the elements starting idx from [first, last)
153 * \param idx The starting point
154 * \param first Begin of iterator
155 * \param last End of iterator
156 * \tparam IterType The type of iterator
157 * \return Self
158 */
159 template <typename IterType>
160 ArrayNode* InitRange(int64_t idx, IterType first, IterType last) {
161 ObjectRef* itr = MutableBegin() + idx;
162 for (; first != last; ++first) {
163 ObjectRef ref = *first;
164 new (itr++) ObjectRef(std::move(ref));
165 }
166 return this;
167 }
168
169 /*!
170 * \brief Move elements from right to left, requires src_begin > dst
171 * \param dst Destination
172 * \param src_begin The start point of copy (inclusive)
173 * \param src_end The end point of copy (exclusive)
174 * \return Self
175 */
176 ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) {
177 ObjectRef* from = MutableBegin() + src_begin;
178 ObjectRef* to = MutableBegin() + dst;
179 while (src_begin++ != src_end) {
180 *to++ = std::move(*from++);
181 }
182 return this;
183 }
184
185 /*!
186 * \brief Move elements from left to right, requires src_begin < dst
187 * \param dst Destination
188 * \param src_begin The start point of move (inclusive)
189 * \param src_end The end point of move (exclusive)
190 * \return Self
191 */
192 ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) {
193 ObjectRef* from = MutableBegin() + src_end;
194 ObjectRef* to = MutableBegin() + (src_end - src_begin + dst);
195 while (src_begin++ != src_end) {
196 *--to = std::move(*--from);
197 }
198 return this;
199 }
200
201 /*!
202 * \brief Enlarges the size of the array
203 * \param delta Size enlarged, should be positive
204 * \param val Default value
205 * \return Self
206 */
207 ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) {
208 ObjectRef* itr = MutableEnd();
209 while (delta-- > 0) {
210 new (itr++) ObjectRef(val);
211 ++size_;
212 }
213 return this;
214 }
215
216 /*!
217 * \brief Shrinks the size of the array
218 * \param delta Size shrinked, should be positive
219 * \return Self
220 */
221 ArrayNode* ShrinkBy(int64_t delta) {
222 ObjectRef* itr = MutableEnd();
223 while (delta-- > 0) {
224 (--itr)->ObjectRef::~ObjectRef();
225 --size_;
226 }
227 return this;
228 }
229
230 /*! \brief Number of elements used */
231 int64_t size_;
232
233 /*! \brief Number of elements allocated */
234 int64_t capacity_;
235
236 /*! \brief Initial size of ArrayNode */
237 static constexpr int64_t kInitSize = 4;
238
239 /*! \brief Expansion factor of the Array */
240 static constexpr int64_t kIncFactor = 2;
241
242 // CRTP parent class
243 friend InplaceArrayBase<ArrayNode, ObjectRef>;
244
245 // Reference class
246 template <typename, typename>
247 friend class Array;
248
249 // To specialize make_object<ArrayNode>
250 friend ObjectPtr<ArrayNode> make_object<>();
251};
252
253/*! \brief Helper struct for type-checking
254 *
255 * is_valid_iterator<T,IterType>::value will be true if IterType can
256 * be dereferenced into a type that can be stored in an Array<T>, and
257 * false otherwise.
258 */
259template <typename T, typename IterType>
260struct is_valid_iterator
261 : std::bool_constant<std::is_base_of_v<
262 T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>>> {};
263
264template <typename T, typename IterType>
265struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {};
266
267template <typename T, typename IterType>
268inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value;
269
270/*!
271 * \brief Array, container representing a contiguous sequence of ObjectRefs.
272 *
273 * Array implements in-place copy-on-write semantics.
274 *
275 * As in typical copy-on-write, a method which would typically mutate the array
276 * instead opaquely copies the underlying container, and then acts on its copy.
277 *
278 * If the array has reference count equal to one, we directly update the
279 * container in place without copying. This is optimization is sound because
280 * when the reference count is equal to one this reference is guranteed to be
281 * the sole pointer to the container.
282 *
283 *
284 * operator[] only provides const access, use Set to mutate the content.
285 * \tparam T The content ObjectRef type.
286 */
287template <typename T,
288 typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
289class Array : public ObjectRef {
290 public:
291 using value_type = T;
292 // constructors
293 /*!
294 * \brief default constructor
295 */
296 Array() { data_ = ArrayNode::Empty(); }
297
298 /*!
299 * \brief move constructor
300 * \param other source
301 */
302 Array(Array<T>&& other) : ObjectRef() { // NOLINT(*)
303 data_ = std::move(other.data_);
304 }
305
306 /*!
307 * \brief copy constructor
308 * \param other source
309 */
310 Array(const Array<T>& other) : ObjectRef() { // NOLINT(*)
311 data_ = other.data_;
312 }
313
314 /*!
315 * \brief constructor from pointer
316 * \param n the container pointer
317 */
318 explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
319
320 /*!
321 * \brief Constructor from iterator
322 * \param first begin of iterator
323 * \param last end of iterator
324 * \tparam IterType The type of iterator
325 */
326 template <typename IterType>
327 Array(IterType first, IterType last) {
328 static_assert(is_valid_iterator_v<T, IterType>,
329 "IterType cannot be inserted into a tvm::Array<T>");
330 Assign(first, last);
331 }
332
333 /*!
334 * \brief constructor from initializer list
335 * \param init The initializer list
336 */
337 Array(std::initializer_list<T> init) { // NOLINT(*)
338 Assign(init.begin(), init.end());
339 }
340
341 /*!
342 * \brief constructor from vector
343 * \param init The vector
344 */
345 Array(const std::vector<T>& init) { // NOLINT(*)
346 Assign(init.begin(), init.end());
347 }
348
349 /*!
350 * \brief Constructs a container with n elements. Each element is a copy of val
351 * \param n The size of the container
352 * \param val The init value
353 */
354 explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); }
355
356 /*!
357 * \brief move assign operator
358 * \param other The source of assignment
359 * \return reference to self.
360 */
361 Array<T>& operator=(Array<T>&& other) {
362 data_ = std::move(other.data_);
363 return *this;
364 }
365
366 /*!
367 * \brief copy assign operator
368 * \param other The source of assignment
369 * \return reference to self.
370 */
371 Array<T>& operator=(const Array<T>& other) {
372 data_ = other.data_;
373 return *this;
374 }
375
376 public:
377 // iterators
378 struct ValueConverter {
379 using ResultType = T;
380 static T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); }
381 };
382
383 using iterator = IterAdapter<ValueConverter, const ObjectRef*>;
384 using reverse_iterator = ReverseIterAdapter<ValueConverter, const ObjectRef*>;
385
386 /*! \return begin iterator */
387 iterator begin() const { return iterator(GetArrayNode()->begin()); }
388
389 /*! \return end iterator */
390 iterator end() const { return iterator(GetArrayNode()->end()); }
391
392 /*! \return rbegin iterator */
393 reverse_iterator rbegin() const {
394 // ArrayNode::end() is never nullptr
395 return reverse_iterator(GetArrayNode()->end() - 1);
396 }
397
398 /*! \return rend iterator */
399 reverse_iterator rend() const {
400 // ArrayNode::begin() is never nullptr
401 return reverse_iterator(GetArrayNode()->begin() - 1);
402 }
403
404 public:
405 // const methods in std::vector
406 /*!
407 * \brief Immutably read i-th element from array.
408 * \param i The index
409 * \return the i-th element.
410 */
411 const T operator[](int64_t i) const {
412 ArrayNode* p = GetArrayNode();
413 ICHECK(p != nullptr) << "ValueError: cannot index a null array";
414 ICHECK(0 <= i && i < p->size_)
415 << "IndexError: indexing " << i << " on an array of size " << p->size_;
416 return DowncastNoCheck<T>(*(p->begin() + i));
417 }
418
419 /*! \return The size of the array */
420 size_t size() const {
421 ArrayNode* p = GetArrayNode();
422 return p == nullptr ? 0 : GetArrayNode()->size_;
423 }
424
425 /*! \return The capacity of the array */
426 size_t capacity() const {
427 ArrayNode* p = GetArrayNode();
428 return p == nullptr ? 0 : GetArrayNode()->capacity_;
429 }
430
431 /*! \return Whether array is empty */
432 bool empty() const { return size() == 0; }
433
434 /*! \return The first element of the array */
435 const T front() const {
436 ArrayNode* p = GetArrayNode();
437 ICHECK(p != nullptr) << "ValueError: cannot index a null array";
438 ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array";
439 return DowncastNoCheck<T>(*(p->begin()));
440 }
441
442 /*! \return The last element of the array */
443 const T back() const {
444 ArrayNode* p = GetArrayNode();
445 ICHECK(p != nullptr) << "ValueError: cannot index a null array";
446 ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array";
447 return DowncastNoCheck<T>(*(p->end() - 1));
448 }
449
450 public:
451 // mutation in std::vector, implements copy-on-write
452
453 /*!
454 * \brief push a new item to the back of the list
455 * \param item The item to be pushed.
456 */
457 void push_back(const T& item) {
458 ArrayNode* p = CopyOnWrite(1);
459 p->EmplaceInit(p->size_++, item);
460 }
461
462 /*!
463 * \brief Insert an element into the given position
464 * \param position An iterator pointing to the insertion point
465 * \param val The element to insert
466 */
467 void insert(iterator position, const T& val) {
468 ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array";
469 int64_t idx = std::distance(begin(), position);
470 int64_t size = GetArrayNode()->size_;
471 auto addr = CopyOnWrite(1) //
472 ->EnlargeBy(1) //
473 ->MoveElementsRight(idx + 1, idx, size) //
474 ->MutableBegin();
475 new (addr + idx) ObjectRef(val);
476 }
477
478 /*!
479 * \brief Insert a range of elements into the given position
480 * \param position An iterator pointing to the insertion point
481 * \param first The begin iterator of the range
482 * \param last The end iterator of the range
483 */
484 template <typename IterType>
485 void insert(iterator position, IterType first, IterType last) {
486 static_assert(is_valid_iterator_v<T, IterType>,
487 "IterType cannot be inserted into a tvm::Array<T>");
488
489 if (first == last) {
490 return;
491 }
492 ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array";
493 int64_t idx = std::distance(begin(), position);
494 int64_t size = GetArrayNode()->size_;
495 int64_t numel = std::distance(first, last);
496 CopyOnWrite(numel)
497 ->EnlargeBy(numel)
498 ->MoveElementsRight(idx + numel, idx, size)
499 ->InitRange(idx, first, last);
500 }
501
502 /*! \brief Remove the last item of the list */
503 void pop_back() {
504 ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null";
505 int64_t size = GetArrayNode()->size_;
506 ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty";
507 CopyOnWrite()->ShrinkBy(1);
508 }
509
510 /*!
511 * \brief Erase an element on the given position
512 * \param position An iterator pointing to the element to be erased
513 */
514 void erase(iterator position) {
515 ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array";
516 int64_t st = std::distance(begin(), position);
517 int64_t size = GetArrayNode()->size_;
518 ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st
519 << ", because Array size is " << size;
520 CopyOnWrite() //
521 ->MoveElementsLeft(st, st + 1, size) //
522 ->ShrinkBy(1);
523 }
524
525 /*!
526 * \brief Erase a given range of elements
527 * \param first The begin iterator of the range
528 * \param last The end iterator of the range
529 */
530 void erase(iterator first, iterator last) {
531 if (first == last) {
532 return;
533 }
534 ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array";
535 int64_t size = GetArrayNode()->size_;
536 int64_t st = std::distance(begin(), first);
537 int64_t ed = std::distance(begin(), last);
538 ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")";
539 ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size)
540 << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"
541 << ", because array size is " << size;
542 CopyOnWrite() //
543 ->MoveElementsLeft(st, ed, size) //
544 ->ShrinkBy(ed - st);
545 }
546
547 /*!
548 * \brief Resize the array.
549 * \param n The new size.
550 */
551 void resize(int64_t n) {
552 ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size";
553 if (data_ == nullptr) {
554 SwitchContainer(n);
555 return;
556 }
557 int64_t size = GetArrayNode()->size_;
558 if (size < n) {
559 CopyOnWrite(n - size)->EnlargeBy(n - size);
560 } else if (size > n) {
561 CopyOnWrite()->ShrinkBy(size - n);
562 }
563 }
564
565 /*!
566 * \brief Make sure the list has the capacity of at least n
567 * \param n lower bound of the capacity
568 */
569 void reserve(int64_t n) {
570 if (data_ == nullptr || n > GetArrayNode()->capacity_) {
571 SwitchContainer(n);
572 }
573 }
574
575 /*! \brief Release reference to all the elements */
576 void clear() {
577 if (data_ != nullptr) {
578 ArrayNode* p = CopyOnWrite();
579 p->clear();
580 }
581 }
582
583 public:
584 // Array's own methods
585
586 /*!
587 * \brief set i-th element of the array.
588 * \param i The index
589 * \param value The value to be setted.
590 */
591 void Set(int64_t i, T value) {
592 ArrayNode* p = this->CopyOnWrite();
593 ICHECK(0 <= i && i < p->size_)
594 << "IndexError: indexing " << i << " on an array of size " << p->size_;
595 *(p->MutableBegin() + i) = std::move(value);
596 }
597
598 /*! \return The underlying ArrayNode */
599 ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); }
600
601 /*!
602 * \brief Helper function to apply a map function onto the array.
603 *
604 * \param fmap The transformation function T -> U.
605 *
606 * \tparam F The type of the mutation function.
607 *
608 * \tparam U The type of the returned array, inferred from the
609 * return type of F. If overridden by the user, must be something
610 * that is convertible from the return type of F.
611 *
612 * \note This function performs copy on write optimization. If
613 * `fmap` returns an object of type `T`, and all elements of the
614 * array are mapped to themselves, then the returned array will be
615 * the same as the original, and reference counts of the elements in
616 * the array will not be incremented.
617 *
618 * \return The transformed array.
619 */
620 template <typename F, typename U = std::invoke_result_t<F, T>>
621 Array<U> Map(F fmap) const {
622 return Array<U>(MapHelper(data_, fmap));
623 }
624
625 /*!
626 * \brief Helper function to apply fmutate to mutate an array.
627 * \param fmutate The transformation function T -> T.
628 * \tparam F the type of the mutation function.
629 * \note This function performs copy on write optimization.
630 */
631 template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>>
632 void MutateByApply(F fmutate) {
633 data_ = MapHelper(std::move(data_), fmutate);
634 }
635
636 /*!
637 * \brief reset the array to content from iterator.
638 * \param first begin of iterator
639 * \param last end of iterator
640 * \tparam IterType The type of iterator
641 */
642 template <typename IterType>
643 void Assign(IterType first, IterType last) {
644 int64_t cap = std::distance(first, last);
645 ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size";
646 ArrayNode* p = GetArrayNode();
647 if (p != nullptr && data_.unique() && p->capacity_ >= cap) {
648 // do not have to make new space
649 p->clear();
650 } else {
651 // create new space
652 data_ = ArrayNode::Empty(cap);
653 p = GetArrayNode();
654 }
655 // To ensure exception safety, size is only incremented after the initialization succeeds
656 ObjectRef* itr = p->MutableBegin();
657 for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) {
658 new (itr) ObjectRef(*first);
659 }
660 }
661
662 /*!
663 * \brief Copy on write semantics
664 * Do nothing if current handle is the unique copy of the array.
665 * Otherwise make a new copy of the array to ensure the current handle
666 * hold a unique copy.
667 *
668 * \return Handle to the internal node container(which ganrantees to be unique)
669 */
670 ArrayNode* CopyOnWrite() {
671 if (data_ == nullptr) {
672 return SwitchContainer(ArrayNode::kInitSize);
673 }
674 if (!data_.unique()) {
675 return SwitchContainer(capacity());
676 }
677 return static_cast<ArrayNode*>(data_.get());
678 }
679
680 /*! \brief specify container node */
681 using ContainerType = ArrayNode;
682
683 private:
684 /*!
685 * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.
686 * \param reserve_extra Number of extra slots needed
687 * \return ArrayNode pointer to the unique copy
688 */
689 ArrayNode* CopyOnWrite(int64_t reserve_extra) {
690 ArrayNode* p = GetArrayNode();
691 if (p == nullptr) {
692 // necessary to get around the constexpr address issue before c++17
693 const int64_t kInitSize = ArrayNode::kInitSize;
694 return SwitchContainer(std::max(kInitSize, reserve_extra));
695 }
696 if (p->capacity_ >= p->size_ + reserve_extra) {
697 return CopyOnWrite();
698 }
699 int64_t cap = p->capacity_ * ArrayNode::kIncFactor;
700 cap = std::max(cap, p->size_ + reserve_extra);
701 return SwitchContainer(cap);
702 }
703
704 /*!
705 * \brief Move or copy the ArrayNode to new address with the given capacity
706 * \param capacity The capacity requirement of the new address
707 */
708 ArrayNode* SwitchContainer(int64_t capacity) {
709 if (data_ == nullptr) {
710 data_ = ArrayNode::Empty(capacity);
711 } else if (data_.unique()) {
712 data_ = ArrayNode::MoveFrom(capacity, GetArrayNode());
713 } else {
714 data_ = ArrayNode::CopyFrom(capacity, GetArrayNode());
715 }
716 return static_cast<ArrayNode*>(data_.get());
717 }
718
719 /*! \brief Helper method for mutate/map
720 *
721 * A helper function used internally by both `Array::Map` and
722 * `Array::MutateInPlace`. Given an array of data, apply the
723 * mapping function to each element, returning the collected array.
724 * Applies both mutate-in-place and copy-on-write optimizations, if
725 * possible.
726 *
727 * \param data A pointer to the ArrayNode containing input data.
728 * Passed by value to allow for mutate-in-place optimizations.
729 *
730 * \param fmap The mapping function
731 *
732 * \tparam F The type of the mutation function.
733 *
734 * \tparam U The output type of the mutation function. Inferred
735 * from the callable type given. Must inherit from ObjectRef.
736 *
737 * \return The mapped array. Depending on whether mutate-in-place
738 * or copy-on-write optimizations were applicable, may be the same
739 * underlying array as the `data` parameter.
740 */
741 template <typename F, typename U = std::invoke_result_t<F, T>>
742 static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
743 if (data == nullptr) {
744 return nullptr;
745 }
746
747 ICHECK(data->IsInstance<ArrayNode>());
748
749 constexpr bool is_same_output_type = std::is_same_v<T, U>;
750
751 if constexpr (is_same_output_type) {
752 if (data.unique()) {
753 // Mutate-in-place path. Only allowed if the output type U is
754 // the same as type T, we have a mutable this*, and there are
755 // no other shared copies of the array.
756 auto arr = static_cast<ArrayNode*>(data.get());
757 for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
758 T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
759 *it = std::move(mapped);
760 }
761 return data;
762 }
763 }
764
765 constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;
766
767 ObjectPtr<ArrayNode> output = nullptr;
768 auto arr = static_cast<ArrayNode*>(data.get());
769
770 auto it = arr->begin();
771 if constexpr (compatible_types) {
772 // Copy-on-write path, if the output Array<U> might be
773 // represented by the same underlying array as the existing
774 // Array<T>. Typically, this is for functions that map `T` to
775 // `T`, but can also apply to functions that map `T` to
776 // `Optional<T>`, or that map `T` to a subclass or superclass of
777 // `T`.
778 bool all_identical = true;
779 for (; it != arr->end(); it++) {
780 U mapped = fmap(DowncastNoCheck<T>(*it));
781 if (!mapped.same_as(*it)) {
782 // At least one mapped element is different than the
783 // original. Therefore, prepare the output array,
784 // consisting of any previous elements that had mapped to
785 // themselves (if any), and the element that didn't map to
786 // itself.
787 all_identical = false;
788 output = ArrayNode::CreateRepeated(arr->size(), U());
789 output->InitRange(0, arr->begin(), it);
790 output->SetItem(it - arr->begin(), std::move(mapped));
791 it++;
792 break;
793 }
794 }
795 if (all_identical) {
796 return data;
797 }
798 } else {
799 // Path for incompatible types. The constexpr check for
800 // compatible types isn't strictly necessary, as the first
801 // mapped.same_as(*it) would return false, but we might as well
802 // avoid it altogether.
803 output = ArrayNode::CreateRepeated(arr->size(), U());
804 }
805
806 // Normal path for incompatible types, or post-copy path for
807 // copy-on-write instances.
808 //
809 // If the types are incompatible, then at this point `output` is
810 // empty, and `it` points to the first element of the input.
811 //
812 // If the types were compatible, then at this point `output`
813 // contains zero or more elements that mapped to themselves
814 // followed by the first element that does not map to itself, and
815 // `it` points to the element just after the first element that
816 // does not map to itself. Because at least one element has been
817 // changed, we no longer have the opportunity to avoid a copy, so
818 // we don't need to check the result.
819 //
820 // In both cases, `it` points to the next element to be processed,
821 // so we can either start or resume the iteration from that point,
822 // with no further checks on the result.
823 for (; it != arr->end(); it++) {
824 U mapped = fmap(DowncastNoCheck<T>(*it));
825 output->SetItem(it - arr->begin(), std::move(mapped));
826 }
827
828 return output;
829 }
830};
831
832/*!
833 * \brief Concat two Arrays.
834 * \param lhs first Array to be concatenated.
835 * \param rhs second Array to be concatenated.
836 * \return The concatenated Array. Original Arrays are kept unchanged.
837 */
838template <typename T,
839 typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
840inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
841 for (const auto& x : rhs) {
842 lhs.push_back(x);
843 }
844 return std::move(lhs);
845}
846
847// Specialize make_object<ArrayNode> to make sure it is correct.
848template <>
849inline ObjectPtr<ArrayNode> make_object() {
850 return ArrayNode::Empty();
851}
852
}  // namespace runtime

// Re-export the Array container types into the tvm namespace.
using runtime::Array;
using runtime::ArrayNode;
}  // namespace tvm
859
860#endif // TVM_RUNTIME_CONTAINER_ARRAY_H_
861