1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/container/array.h |
22 | * \brief Runtime Array container types. |
23 | */ |
24 | #ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ |
25 | #define TVM_RUNTIME_CONTAINER_ARRAY_H_ |
26 | |
27 | #include <algorithm> |
28 | #include <memory> |
29 | #include <type_traits> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | #include "./base.h" |
34 | #include "./optional.h" |
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | |
39 | /*! \brief array node content in array */ |
40 | class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> { |
41 | public: |
42 | /*! \return The size of the array */ |
43 | size_t size() const { return this->size_; } |
44 | |
45 | /*! |
46 | * \brief Read i-th element from array. |
47 | * \param i The index |
48 | * \return the i-th element. |
49 | */ |
50 | const ObjectRef at(int64_t i) const { return this->operator[](i); } |
51 | |
52 | /*! \return begin constant iterator */ |
53 | const ObjectRef* begin() const { return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); } |
54 | |
55 | /*! \return end constant iterator */ |
56 | const ObjectRef* end() const { return begin() + size_; } |
57 | |
58 | /*! \brief Release reference to all the elements */ |
59 | void clear() { ShrinkBy(size_); } |
60 | |
61 | /*! |
62 | * \brief Set i-th element of the array in-place |
63 | * \param i The index |
64 | * \param item The value to be set |
65 | */ |
66 | void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } |
67 | |
68 | /*! |
69 | * \brief Constructs a container and copy from another |
70 | * \param cap The capacity of the container |
71 | * \param from Source of the copy |
72 | * \return Ref-counted ArrayNode requested |
73 | */ |
74 | static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) { |
75 | int64_t size = from->size_; |
76 | ICHECK_GE(cap, size) << "ValueError: not enough capacity" ; |
77 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); |
78 | ObjectRef* write = p->MutableBegin(); |
79 | ObjectRef* read = from->MutableBegin(); |
80 | // To ensure exception safety, size is only incremented after the initialization succeeds |
81 | for (int64_t& i = p->size_ = 0; i < size; ++i) { |
82 | new (write++) ObjectRef(*read++); |
83 | } |
84 | return p; |
85 | } |
86 | |
87 | /*! |
88 | * \brief Constructs a container and move from another |
89 | * \param cap The capacity of the container |
90 | * \param from Source of the move |
91 | * \return Ref-counted ArrayNode requested |
92 | */ |
93 | static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) { |
94 | int64_t size = from->size_; |
95 | ICHECK_GE(cap, size) << "ValueError: not enough capacity" ; |
96 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); |
97 | ObjectRef* write = p->MutableBegin(); |
98 | ObjectRef* read = from->MutableBegin(); |
99 | // To ensure exception safety, size is only incremented after the initialization succeeds |
100 | for (int64_t& i = p->size_ = 0; i < size; ++i) { |
101 | new (write++) ObjectRef(std::move(*read++)); |
102 | } |
103 | from->size_ = 0; |
104 | return p; |
105 | } |
106 | |
107 | /*! |
108 | * \brief Constructs a container with n elements. Each element is a copy of val |
109 | * \param n The size of the container |
110 | * \param val The init value |
111 | * \return Ref-counted ArrayNode requested |
112 | */ |
113 | static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const ObjectRef& val) { |
114 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(n); |
115 | ObjectRef* itr = p->MutableBegin(); |
116 | for (int64_t& i = p->size_ = 0; i < n; ++i) { |
117 | new (itr++) ObjectRef(val); |
118 | } |
119 | return p; |
120 | } |
121 | |
122 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; |
123 | static constexpr const char* _type_key = "Array" ; |
124 | TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); |
125 | |
126 | private: |
127 | /*! \return Size of initialized memory, used by InplaceArrayBase. */ |
128 | size_t GetSize() const { return this->size_; } |
129 | |
130 | /*! \return begin mutable iterator */ |
131 | ObjectRef* MutableBegin() const { |
132 | return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); |
133 | } |
134 | |
135 | /*! \return end mutable iterator */ |
136 | ObjectRef* MutableEnd() const { return MutableBegin() + size_; } |
137 | |
138 | /*! |
139 | * \brief Create an ArrayNode with the given capacity. |
140 | * \param n Required capacity |
141 | * \return Ref-counted ArrayNode requested |
142 | */ |
143 | static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) { |
144 | ICHECK_GE(n, 0); |
145 | ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, ObjectRef>(n); |
146 | p->capacity_ = n; |
147 | p->size_ = 0; |
148 | return p; |
149 | } |
150 | |
151 | /*! |
152 | * \brief Inplace-initialize the elements starting idx from [first, last) |
153 | * \param idx The starting point |
154 | * \param first Begin of iterator |
155 | * \param last End of iterator |
156 | * \tparam IterType The type of iterator |
157 | * \return Self |
158 | */ |
159 | template <typename IterType> |
160 | ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { |
161 | ObjectRef* itr = MutableBegin() + idx; |
162 | for (; first != last; ++first) { |
163 | ObjectRef ref = *first; |
164 | new (itr++) ObjectRef(std::move(ref)); |
165 | } |
166 | return this; |
167 | } |
168 | |
169 | /*! |
170 | * \brief Move elements from right to left, requires src_begin > dst |
171 | * \param dst Destination |
172 | * \param src_begin The start point of copy (inclusive) |
173 | * \param src_end The end point of copy (exclusive) |
174 | * \return Self |
175 | */ |
176 | ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { |
177 | ObjectRef* from = MutableBegin() + src_begin; |
178 | ObjectRef* to = MutableBegin() + dst; |
179 | while (src_begin++ != src_end) { |
180 | *to++ = std::move(*from++); |
181 | } |
182 | return this; |
183 | } |
184 | |
185 | /*! |
186 | * \brief Move elements from left to right, requires src_begin < dst |
187 | * \param dst Destination |
188 | * \param src_begin The start point of move (inclusive) |
189 | * \param src_end The end point of move (exclusive) |
190 | * \return Self |
191 | */ |
192 | ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { |
193 | ObjectRef* from = MutableBegin() + src_end; |
194 | ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); |
195 | while (src_begin++ != src_end) { |
196 | *--to = std::move(*--from); |
197 | } |
198 | return this; |
199 | } |
200 | |
201 | /*! |
202 | * \brief Enlarges the size of the array |
203 | * \param delta Size enlarged, should be positive |
204 | * \param val Default value |
205 | * \return Self |
206 | */ |
207 | ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { |
208 | ObjectRef* itr = MutableEnd(); |
209 | while (delta-- > 0) { |
210 | new (itr++) ObjectRef(val); |
211 | ++size_; |
212 | } |
213 | return this; |
214 | } |
215 | |
216 | /*! |
217 | * \brief Shrinks the size of the array |
218 | * \param delta Size shrinked, should be positive |
219 | * \return Self |
220 | */ |
221 | ArrayNode* ShrinkBy(int64_t delta) { |
222 | ObjectRef* itr = MutableEnd(); |
223 | while (delta-- > 0) { |
224 | (--itr)->ObjectRef::~ObjectRef(); |
225 | --size_; |
226 | } |
227 | return this; |
228 | } |
229 | |
230 | /*! \brief Number of elements used */ |
231 | int64_t size_; |
232 | |
233 | /*! \brief Number of elements allocated */ |
234 | int64_t capacity_; |
235 | |
236 | /*! \brief Initial size of ArrayNode */ |
237 | static constexpr int64_t kInitSize = 4; |
238 | |
239 | /*! \brief Expansion factor of the Array */ |
240 | static constexpr int64_t kIncFactor = 2; |
241 | |
242 | // CRTP parent class |
243 | friend InplaceArrayBase<ArrayNode, ObjectRef>; |
244 | |
245 | // Reference class |
246 | template <typename, typename> |
247 | friend class Array; |
248 | |
249 | // To specialize make_object<ArrayNode> |
250 | friend ObjectPtr<ArrayNode> make_object<>(); |
251 | }; |
252 | |
253 | /*! \brief Helper struct for type-checking |
254 | * |
255 | * is_valid_iterator<T,IterType>::value will be true if IterType can |
256 | * be dereferenced into a type that can be stored in an Array<T>, and |
257 | * false otherwise. |
258 | */ |
259 | template <typename T, typename IterType> |
260 | struct is_valid_iterator |
261 | : std::bool_constant<std::is_base_of_v< |
262 | T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>>> {}; |
263 | |
264 | template <typename T, typename IterType> |
265 | struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {}; |
266 | |
267 | template <typename T, typename IterType> |
268 | inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value; |
269 | |
270 | /*! |
271 | * \brief Array, container representing a contiguous sequence of ObjectRefs. |
272 | * |
273 | * Array implements in-place copy-on-write semantics. |
274 | * |
275 | * As in typical copy-on-write, a method which would typically mutate the array |
276 | * instead opaquely copies the underlying container, and then acts on its copy. |
277 | * |
278 | * If the array has reference count equal to one, we directly update the |
279 | * container in place without copying. This is optimization is sound because |
280 | * when the reference count is equal to one this reference is guranteed to be |
281 | * the sole pointer to the container. |
282 | * |
283 | * |
284 | * operator[] only provides const access, use Set to mutate the content. |
285 | * \tparam T The content ObjectRef type. |
286 | */ |
287 | template <typename T, |
288 | typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> |
289 | class Array : public ObjectRef { |
290 | public: |
291 | using value_type = T; |
292 | // constructors |
293 | /*! |
294 | * \brief default constructor |
295 | */ |
296 | Array() { data_ = ArrayNode::Empty(); } |
297 | |
298 | /*! |
299 | * \brief move constructor |
300 | * \param other source |
301 | */ |
302 | Array(Array<T>&& other) : ObjectRef() { // NOLINT(*) |
303 | data_ = std::move(other.data_); |
304 | } |
305 | |
306 | /*! |
307 | * \brief copy constructor |
308 | * \param other source |
309 | */ |
310 | Array(const Array<T>& other) : ObjectRef() { // NOLINT(*) |
311 | data_ = other.data_; |
312 | } |
313 | |
314 | /*! |
315 | * \brief constructor from pointer |
316 | * \param n the container pointer |
317 | */ |
318 | explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {} |
319 | |
320 | /*! |
321 | * \brief Constructor from iterator |
322 | * \param first begin of iterator |
323 | * \param last end of iterator |
324 | * \tparam IterType The type of iterator |
325 | */ |
326 | template <typename IterType> |
327 | Array(IterType first, IterType last) { |
328 | static_assert(is_valid_iterator_v<T, IterType>, |
329 | "IterType cannot be inserted into a tvm::Array<T>" ); |
330 | Assign(first, last); |
331 | } |
332 | |
333 | /*! |
334 | * \brief constructor from initializer list |
335 | * \param init The initializer list |
336 | */ |
337 | Array(std::initializer_list<T> init) { // NOLINT(*) |
338 | Assign(init.begin(), init.end()); |
339 | } |
340 | |
341 | /*! |
342 | * \brief constructor from vector |
343 | * \param init The vector |
344 | */ |
345 | Array(const std::vector<T>& init) { // NOLINT(*) |
346 | Assign(init.begin(), init.end()); |
347 | } |
348 | |
349 | /*! |
350 | * \brief Constructs a container with n elements. Each element is a copy of val |
351 | * \param n The size of the container |
352 | * \param val The init value |
353 | */ |
354 | explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } |
355 | |
356 | /*! |
357 | * \brief move assign operator |
358 | * \param other The source of assignment |
359 | * \return reference to self. |
360 | */ |
361 | Array<T>& operator=(Array<T>&& other) { |
362 | data_ = std::move(other.data_); |
363 | return *this; |
364 | } |
365 | |
366 | /*! |
367 | * \brief copy assign operator |
368 | * \param other The source of assignment |
369 | * \return reference to self. |
370 | */ |
371 | Array<T>& operator=(const Array<T>& other) { |
372 | data_ = other.data_; |
373 | return *this; |
374 | } |
375 | |
376 | public: |
377 | // iterators |
378 | struct ValueConverter { |
379 | using ResultType = T; |
380 | static T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); } |
381 | }; |
382 | |
383 | using iterator = IterAdapter<ValueConverter, const ObjectRef*>; |
384 | using reverse_iterator = ReverseIterAdapter<ValueConverter, const ObjectRef*>; |
385 | |
386 | /*! \return begin iterator */ |
387 | iterator begin() const { return iterator(GetArrayNode()->begin()); } |
388 | |
389 | /*! \return end iterator */ |
390 | iterator end() const { return iterator(GetArrayNode()->end()); } |
391 | |
392 | /*! \return rbegin iterator */ |
393 | reverse_iterator rbegin() const { |
394 | // ArrayNode::end() is never nullptr |
395 | return reverse_iterator(GetArrayNode()->end() - 1); |
396 | } |
397 | |
398 | /*! \return rend iterator */ |
399 | reverse_iterator rend() const { |
400 | // ArrayNode::begin() is never nullptr |
401 | return reverse_iterator(GetArrayNode()->begin() - 1); |
402 | } |
403 | |
404 | public: |
405 | // const methods in std::vector |
406 | /*! |
407 | * \brief Immutably read i-th element from array. |
408 | * \param i The index |
409 | * \return the i-th element. |
410 | */ |
411 | const T operator[](int64_t i) const { |
412 | ArrayNode* p = GetArrayNode(); |
413 | ICHECK(p != nullptr) << "ValueError: cannot index a null array" ; |
414 | ICHECK(0 <= i && i < p->size_) |
415 | << "IndexError: indexing " << i << " on an array of size " << p->size_; |
416 | return DowncastNoCheck<T>(*(p->begin() + i)); |
417 | } |
418 | |
419 | /*! \return The size of the array */ |
420 | size_t size() const { |
421 | ArrayNode* p = GetArrayNode(); |
422 | return p == nullptr ? 0 : GetArrayNode()->size_; |
423 | } |
424 | |
425 | /*! \return The capacity of the array */ |
426 | size_t capacity() const { |
427 | ArrayNode* p = GetArrayNode(); |
428 | return p == nullptr ? 0 : GetArrayNode()->capacity_; |
429 | } |
430 | |
431 | /*! \return Whether array is empty */ |
432 | bool empty() const { return size() == 0; } |
433 | |
434 | /*! \return The first element of the array */ |
435 | const T front() const { |
436 | ArrayNode* p = GetArrayNode(); |
437 | ICHECK(p != nullptr) << "ValueError: cannot index a null array" ; |
438 | ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array" ; |
439 | return DowncastNoCheck<T>(*(p->begin())); |
440 | } |
441 | |
442 | /*! \return The last element of the array */ |
443 | const T back() const { |
444 | ArrayNode* p = GetArrayNode(); |
445 | ICHECK(p != nullptr) << "ValueError: cannot index a null array" ; |
446 | ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array" ; |
447 | return DowncastNoCheck<T>(*(p->end() - 1)); |
448 | } |
449 | |
450 | public: |
451 | // mutation in std::vector, implements copy-on-write |
452 | |
453 | /*! |
454 | * \brief push a new item to the back of the list |
455 | * \param item The item to be pushed. |
456 | */ |
457 | void push_back(const T& item) { |
458 | ArrayNode* p = CopyOnWrite(1); |
459 | p->EmplaceInit(p->size_++, item); |
460 | } |
461 | |
462 | /*! |
463 | * \brief Insert an element into the given position |
464 | * \param position An iterator pointing to the insertion point |
465 | * \param val The element to insert |
466 | */ |
467 | void insert(iterator position, const T& val) { |
468 | ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array" ; |
469 | int64_t idx = std::distance(begin(), position); |
470 | int64_t size = GetArrayNode()->size_; |
471 | auto addr = CopyOnWrite(1) // |
472 | ->EnlargeBy(1) // |
473 | ->MoveElementsRight(idx + 1, idx, size) // |
474 | ->MutableBegin(); |
475 | new (addr + idx) ObjectRef(val); |
476 | } |
477 | |
478 | /*! |
479 | * \brief Insert a range of elements into the given position |
480 | * \param position An iterator pointing to the insertion point |
481 | * \param first The begin iterator of the range |
482 | * \param last The end iterator of the range |
483 | */ |
484 | template <typename IterType> |
485 | void insert(iterator position, IterType first, IterType last) { |
486 | static_assert(is_valid_iterator_v<T, IterType>, |
487 | "IterType cannot be inserted into a tvm::Array<T>" ); |
488 | |
489 | if (first == last) { |
490 | return; |
491 | } |
492 | ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array" ; |
493 | int64_t idx = std::distance(begin(), position); |
494 | int64_t size = GetArrayNode()->size_; |
495 | int64_t numel = std::distance(first, last); |
496 | CopyOnWrite(numel) |
497 | ->EnlargeBy(numel) |
498 | ->MoveElementsRight(idx + numel, idx, size) |
499 | ->InitRange(idx, first, last); |
500 | } |
501 | |
502 | /*! \brief Remove the last item of the list */ |
503 | void pop_back() { |
504 | ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null" ; |
505 | int64_t size = GetArrayNode()->size_; |
506 | ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty" ; |
507 | CopyOnWrite()->ShrinkBy(1); |
508 | } |
509 | |
510 | /*! |
511 | * \brief Erase an element on the given position |
512 | * \param position An iterator pointing to the element to be erased |
513 | */ |
514 | void erase(iterator position) { |
515 | ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array" ; |
516 | int64_t st = std::distance(begin(), position); |
517 | int64_t size = GetArrayNode()->size_; |
518 | ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st |
519 | << ", because Array size is " << size; |
520 | CopyOnWrite() // |
521 | ->MoveElementsLeft(st, st + 1, size) // |
522 | ->ShrinkBy(1); |
523 | } |
524 | |
525 | /*! |
526 | * \brief Erase a given range of elements |
527 | * \param first The begin iterator of the range |
528 | * \param last The end iterator of the range |
529 | */ |
530 | void erase(iterator first, iterator last) { |
531 | if (first == last) { |
532 | return; |
533 | } |
534 | ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array" ; |
535 | int64_t size = GetArrayNode()->size_; |
536 | int64_t st = std::distance(begin(), first); |
537 | int64_t ed = std::distance(begin(), last); |
538 | ICHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" ; |
539 | ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) |
540 | << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" |
541 | << ", because array size is " << size; |
542 | CopyOnWrite() // |
543 | ->MoveElementsLeft(st, ed, size) // |
544 | ->ShrinkBy(ed - st); |
545 | } |
546 | |
547 | /*! |
548 | * \brief Resize the array. |
549 | * \param n The new size. |
550 | */ |
551 | void resize(int64_t n) { |
552 | ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size" ; |
553 | if (data_ == nullptr) { |
554 | SwitchContainer(n); |
555 | return; |
556 | } |
557 | int64_t size = GetArrayNode()->size_; |
558 | if (size < n) { |
559 | CopyOnWrite(n - size)->EnlargeBy(n - size); |
560 | } else if (size > n) { |
561 | CopyOnWrite()->ShrinkBy(size - n); |
562 | } |
563 | } |
564 | |
565 | /*! |
566 | * \brief Make sure the list has the capacity of at least n |
567 | * \param n lower bound of the capacity |
568 | */ |
569 | void reserve(int64_t n) { |
570 | if (data_ == nullptr || n > GetArrayNode()->capacity_) { |
571 | SwitchContainer(n); |
572 | } |
573 | } |
574 | |
575 | /*! \brief Release reference to all the elements */ |
576 | void clear() { |
577 | if (data_ != nullptr) { |
578 | ArrayNode* p = CopyOnWrite(); |
579 | p->clear(); |
580 | } |
581 | } |
582 | |
583 | public: |
584 | // Array's own methods |
585 | |
586 | /*! |
587 | * \brief set i-th element of the array. |
588 | * \param i The index |
589 | * \param value The value to be setted. |
590 | */ |
591 | void Set(int64_t i, T value) { |
592 | ArrayNode* p = this->CopyOnWrite(); |
593 | ICHECK(0 <= i && i < p->size_) |
594 | << "IndexError: indexing " << i << " on an array of size " << p->size_; |
595 | *(p->MutableBegin() + i) = std::move(value); |
596 | } |
597 | |
598 | /*! \return The underlying ArrayNode */ |
599 | ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); } |
600 | |
601 | /*! |
602 | * \brief Helper function to apply a map function onto the array. |
603 | * |
604 | * \param fmap The transformation function T -> U. |
605 | * |
606 | * \tparam F The type of the mutation function. |
607 | * |
608 | * \tparam U The type of the returned array, inferred from the |
609 | * return type of F. If overridden by the user, must be something |
610 | * that is convertible from the return type of F. |
611 | * |
612 | * \note This function performs copy on write optimization. If |
613 | * `fmap` returns an object of type `T`, and all elements of the |
614 | * array are mapped to themselves, then the returned array will be |
615 | * the same as the original, and reference counts of the elements in |
616 | * the array will not be incremented. |
617 | * |
618 | * \return The transformed array. |
619 | */ |
620 | template <typename F, typename U = std::invoke_result_t<F, T>> |
621 | Array<U> Map(F fmap) const { |
622 | return Array<U>(MapHelper(data_, fmap)); |
623 | } |
624 | |
625 | /*! |
626 | * \brief Helper function to apply fmutate to mutate an array. |
627 | * \param fmutate The transformation function T -> T. |
628 | * \tparam F the type of the mutation function. |
629 | * \note This function performs copy on write optimization. |
630 | */ |
631 | template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>> |
632 | void MutateByApply(F fmutate) { |
633 | data_ = MapHelper(std::move(data_), fmutate); |
634 | } |
635 | |
636 | /*! |
637 | * \brief reset the array to content from iterator. |
638 | * \param first begin of iterator |
639 | * \param last end of iterator |
640 | * \tparam IterType The type of iterator |
641 | */ |
642 | template <typename IterType> |
643 | void Assign(IterType first, IterType last) { |
644 | int64_t cap = std::distance(first, last); |
645 | ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size" ; |
646 | ArrayNode* p = GetArrayNode(); |
647 | if (p != nullptr && data_.unique() && p->capacity_ >= cap) { |
648 | // do not have to make new space |
649 | p->clear(); |
650 | } else { |
651 | // create new space |
652 | data_ = ArrayNode::Empty(cap); |
653 | p = GetArrayNode(); |
654 | } |
655 | // To ensure exception safety, size is only incremented after the initialization succeeds |
656 | ObjectRef* itr = p->MutableBegin(); |
657 | for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { |
658 | new (itr) ObjectRef(*first); |
659 | } |
660 | } |
661 | |
662 | /*! |
663 | * \brief Copy on write semantics |
664 | * Do nothing if current handle is the unique copy of the array. |
665 | * Otherwise make a new copy of the array to ensure the current handle |
666 | * hold a unique copy. |
667 | * |
668 | * \return Handle to the internal node container(which ganrantees to be unique) |
669 | */ |
670 | ArrayNode* CopyOnWrite() { |
671 | if (data_ == nullptr) { |
672 | return SwitchContainer(ArrayNode::kInitSize); |
673 | } |
674 | if (!data_.unique()) { |
675 | return SwitchContainer(capacity()); |
676 | } |
677 | return static_cast<ArrayNode*>(data_.get()); |
678 | } |
679 | |
680 | /*! \brief specify container node */ |
681 | using ContainerType = ArrayNode; |
682 | |
683 | private: |
684 | /*! |
685 | * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. |
686 | * \param reserve_extra Number of extra slots needed |
687 | * \return ArrayNode pointer to the unique copy |
688 | */ |
689 | ArrayNode* CopyOnWrite(int64_t ) { |
690 | ArrayNode* p = GetArrayNode(); |
691 | if (p == nullptr) { |
692 | // necessary to get around the constexpr address issue before c++17 |
693 | const int64_t kInitSize = ArrayNode::kInitSize; |
694 | return SwitchContainer(std::max(kInitSize, reserve_extra)); |
695 | } |
696 | if (p->capacity_ >= p->size_ + reserve_extra) { |
697 | return CopyOnWrite(); |
698 | } |
699 | int64_t cap = p->capacity_ * ArrayNode::kIncFactor; |
700 | cap = std::max(cap, p->size_ + reserve_extra); |
701 | return SwitchContainer(cap); |
702 | } |
703 | |
704 | /*! |
705 | * \brief Move or copy the ArrayNode to new address with the given capacity |
706 | * \param capacity The capacity requirement of the new address |
707 | */ |
708 | ArrayNode* SwitchContainer(int64_t capacity) { |
709 | if (data_ == nullptr) { |
710 | data_ = ArrayNode::Empty(capacity); |
711 | } else if (data_.unique()) { |
712 | data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); |
713 | } else { |
714 | data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); |
715 | } |
716 | return static_cast<ArrayNode*>(data_.get()); |
717 | } |
718 | |
719 | /*! \brief Helper method for mutate/map |
720 | * |
721 | * A helper function used internally by both `Array::Map` and |
722 | * `Array::MutateInPlace`. Given an array of data, apply the |
723 | * mapping function to each element, returning the collected array. |
724 | * Applies both mutate-in-place and copy-on-write optimizations, if |
725 | * possible. |
726 | * |
727 | * \param data A pointer to the ArrayNode containing input data. |
728 | * Passed by value to allow for mutate-in-place optimizations. |
729 | * |
730 | * \param fmap The mapping function |
731 | * |
732 | * \tparam F The type of the mutation function. |
733 | * |
734 | * \tparam U The output type of the mutation function. Inferred |
735 | * from the callable type given. Must inherit from ObjectRef. |
736 | * |
737 | * \return The mapped array. Depending on whether mutate-in-place |
738 | * or copy-on-write optimizations were applicable, may be the same |
739 | * underlying array as the `data` parameter. |
740 | */ |
741 | template <typename F, typename U = std::invoke_result_t<F, T>> |
742 | static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) { |
743 | if (data == nullptr) { |
744 | return nullptr; |
745 | } |
746 | |
747 | ICHECK(data->IsInstance<ArrayNode>()); |
748 | |
749 | constexpr bool is_same_output_type = std::is_same_v<T, U>; |
750 | |
751 | if constexpr (is_same_output_type) { |
752 | if (data.unique()) { |
753 | // Mutate-in-place path. Only allowed if the output type U is |
754 | // the same as type T, we have a mutable this*, and there are |
755 | // no other shared copies of the array. |
756 | auto arr = static_cast<ArrayNode*>(data.get()); |
757 | for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { |
758 | T mapped = fmap(DowncastNoCheck<T>(std::move(*it))); |
759 | *it = std::move(mapped); |
760 | } |
761 | return data; |
762 | } |
763 | } |
764 | |
765 | constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>; |
766 | |
767 | ObjectPtr<ArrayNode> output = nullptr; |
768 | auto arr = static_cast<ArrayNode*>(data.get()); |
769 | |
770 | auto it = arr->begin(); |
771 | if constexpr (compatible_types) { |
772 | // Copy-on-write path, if the output Array<U> might be |
773 | // represented by the same underlying array as the existing |
774 | // Array<T>. Typically, this is for functions that map `T` to |
775 | // `T`, but can also apply to functions that map `T` to |
776 | // `Optional<T>`, or that map `T` to a subclass or superclass of |
777 | // `T`. |
778 | bool all_identical = true; |
779 | for (; it != arr->end(); it++) { |
780 | U mapped = fmap(DowncastNoCheck<T>(*it)); |
781 | if (!mapped.same_as(*it)) { |
782 | // At least one mapped element is different than the |
783 | // original. Therefore, prepare the output array, |
784 | // consisting of any previous elements that had mapped to |
785 | // themselves (if any), and the element that didn't map to |
786 | // itself. |
787 | all_identical = false; |
788 | output = ArrayNode::CreateRepeated(arr->size(), U()); |
789 | output->InitRange(0, arr->begin(), it); |
790 | output->SetItem(it - arr->begin(), std::move(mapped)); |
791 | it++; |
792 | break; |
793 | } |
794 | } |
795 | if (all_identical) { |
796 | return data; |
797 | } |
798 | } else { |
799 | // Path for incompatible types. The constexpr check for |
800 | // compatible types isn't strictly necessary, as the first |
801 | // mapped.same_as(*it) would return false, but we might as well |
802 | // avoid it altogether. |
803 | output = ArrayNode::CreateRepeated(arr->size(), U()); |
804 | } |
805 | |
806 | // Normal path for incompatible types, or post-copy path for |
807 | // copy-on-write instances. |
808 | // |
809 | // If the types are incompatible, then at this point `output` is |
810 | // empty, and `it` points to the first element of the input. |
811 | // |
812 | // If the types were compatible, then at this point `output` |
813 | // contains zero or more elements that mapped to themselves |
814 | // followed by the first element that does not map to itself, and |
815 | // `it` points to the element just after the first element that |
816 | // does not map to itself. Because at least one element has been |
817 | // changed, we no longer have the opportunity to avoid a copy, so |
818 | // we don't need to check the result. |
819 | // |
820 | // In both cases, `it` points to the next element to be processed, |
821 | // so we can either start or resume the iteration from that point, |
822 | // with no further checks on the result. |
823 | for (; it != arr->end(); it++) { |
824 | U mapped = fmap(DowncastNoCheck<T>(*it)); |
825 | output->SetItem(it - arr->begin(), std::move(mapped)); |
826 | } |
827 | |
828 | return output; |
829 | } |
830 | }; |
831 | |
832 | /*! |
833 | * \brief Concat two Arrays. |
834 | * \param lhs first Array to be concatenated. |
835 | * \param rhs second Array to be concatenated. |
836 | * \return The concatenated Array. Original Arrays are kept unchanged. |
837 | */ |
838 | template <typename T, |
839 | typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> |
840 | inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) { |
841 | for (const auto& x : rhs) { |
842 | lhs.push_back(x); |
843 | } |
844 | return std::move(lhs); |
845 | } |
846 | |
847 | // Specialize make_object<ArrayNode> to make sure it is correct. |
848 | template <> |
849 | inline ObjectPtr<ArrayNode> make_object() { |
850 | return ArrayNode::Empty(); |
851 | } |
852 | |
853 | } // namespace runtime |
854 | |
855 | // expose the functions to the root namespace. |
856 | using runtime::Array; |
857 | using runtime::ArrayNode; |
858 | } // namespace tvm |
859 | |
860 | #endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ |
861 | |