1 | /* |
---|---|

2 | * Licensed to the Apache Software Foundation (ASF) under one |

3 | * or more contributor license agreements. See the NOTICE file |

4 | * distributed with this work for additional information |

5 | * regarding copyright ownership. The ASF licenses this file |

6 | * to you under the Apache License, Version 2.0 (the |

7 | * "License"); you may not use this file except in compliance |

8 | * with the License. You may obtain a copy of the License at |

9 | * |

10 | * http://www.apache.org/licenses/LICENSE-2.0 |

11 | * |

12 | * Unless required by applicable law or agreed to in writing, |

13 | * software distributed under the License is distributed on an |

14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |

15 | * KIND, either express or implied. See the License for the |

16 | * specific language governing permissions and limitations |

17 | * under the License. |

18 | */ |

19 | |

20 | /*! |

21 | * \file tvm/runtime/container/array.h |

22 | * \brief Runtime Array container types. |

23 | */ |

24 | #ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ |

25 | #define TVM_RUNTIME_CONTAINER_ARRAY_H_ |

26 | |

27 | #include <algorithm> |

28 | #include <memory> |

29 | #include <type_traits> |

30 | #include <utility> |

31 | #include <vector> |

32 | |

33 | #include "./base.h" |

34 | #include "./optional.h" |

35 | |

36 | namespace tvm { |

37 | namespace runtime { |

38 | |

39 | /*! \brief array node content in array */ |

40 | class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> { |

41 | public: |

42 | /*! \return The size of the array */ |

43 | size_t size() const { return this->size_; } |

44 | |

45 | /*! |

46 | * \brief Read i-th element from array. |

47 | * \param i The index |

48 | * \return the i-th element. |

49 | */ |

50 | const ObjectRef at(int64_t i) const { return this->operator[](i); } |

51 | |

52 | /*! \return begin constant iterator */ |

53 | const ObjectRef* begin() const { return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); } |

54 | |

55 | /*! \return end constant iterator */ |

56 | const ObjectRef* end() const { return begin() + size_; } |

57 | |

58 | /*! \brief Release reference to all the elements */ |

59 | void clear() { ShrinkBy(size_); } |

60 | |

61 | /*! |

62 | * \brief Set i-th element of the array in-place |

63 | * \param i The index |

64 | * \param item The value to be set |

65 | */ |

66 | void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } |

67 | |

68 | /*! |

69 | * \brief Constructs a container and copy from another |

70 | * \param cap The capacity of the container |

71 | * \param from Source of the copy |

72 | * \return Ref-counted ArrayNode requested |

73 | */ |

74 | static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) { |

75 | int64_t size = from->size_; |

76 | ICHECK_GE(cap, size) << "ValueError: not enough capacity"; |

77 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); |

78 | ObjectRef* write = p->MutableBegin(); |

79 | ObjectRef* read = from->MutableBegin(); |

80 | // To ensure exception safety, size is only incremented after the initialization succeeds |

81 | for (int64_t& i = p->size_ = 0; i < size; ++i) { |

82 | new (write++) ObjectRef(*read++); |

83 | } |

84 | return p; |

85 | } |

86 | |

87 | /*! |

88 | * \brief Constructs a container and move from another |

89 | * \param cap The capacity of the container |

90 | * \param from Source of the move |

91 | * \return Ref-counted ArrayNode requested |

92 | */ |

93 | static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) { |

94 | int64_t size = from->size_; |

95 | ICHECK_GE(cap, size) << "ValueError: not enough capacity"; |

96 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); |

97 | ObjectRef* write = p->MutableBegin(); |

98 | ObjectRef* read = from->MutableBegin(); |

99 | // To ensure exception safety, size is only incremented after the initialization succeeds |

100 | for (int64_t& i = p->size_ = 0; i < size; ++i) { |

101 | new (write++) ObjectRef(std::move(*read++)); |

102 | } |

103 | from->size_ = 0; |

104 | return p; |

105 | } |

106 | |

107 | /*! |

108 | * \brief Constructs a container with n elements. Each element is a copy of val |

109 | * \param n The size of the container |

110 | * \param val The init value |

111 | * \return Ref-counted ArrayNode requested |

112 | */ |

113 | static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const ObjectRef& val) { |

114 | ObjectPtr<ArrayNode> p = ArrayNode::Empty(n); |

115 | ObjectRef* itr = p->MutableBegin(); |

116 | for (int64_t& i = p->size_ = 0; i < n; ++i) { |

117 | new (itr++) ObjectRef(val); |

118 | } |

119 | return p; |

120 | } |

121 | |

122 | static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; |

123 | static constexpr const char* _type_key = "Array"; |

124 | TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); |

125 | |

126 | private: |

127 | /*! \return Size of initialized memory, used by InplaceArrayBase. */ |

128 | size_t GetSize() const { return this->size_; } |

129 | |

130 | /*! \return begin mutable iterator */ |

131 | ObjectRef* MutableBegin() const { |

132 | return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); |

133 | } |

134 | |

135 | /*! \return end mutable iterator */ |

136 | ObjectRef* MutableEnd() const { return MutableBegin() + size_; } |

137 | |

138 | /*! |

139 | * \brief Create an ArrayNode with the given capacity. |

140 | * \param n Required capacity |

141 | * \return Ref-counted ArrayNode requested |

142 | */ |

143 | static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) { |

144 | ICHECK_GE(n, 0); |

145 | ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, ObjectRef>(n); |

146 | p->capacity_ = n; |

147 | p->size_ = 0; |

148 | return p; |

149 | } |

150 | |

151 | /*! |

152 | * \brief Inplace-initialize the elements starting idx from [first, last) |

153 | * \param idx The starting point |

154 | * \param first Begin of iterator |

155 | * \param last End of iterator |

156 | * \tparam IterType The type of iterator |

157 | * \return Self |

158 | */ |

159 | template <typename IterType> |

160 | ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { |

161 | ObjectRef* itr = MutableBegin() + idx; |

162 | for (; first != last; ++first) { |

163 | ObjectRef ref = *first; |

164 | new (itr++) ObjectRef(std::move(ref)); |

165 | } |

166 | return this; |

167 | } |

168 | |

169 | /*! |

170 | * \brief Move elements from right to left, requires src_begin > dst |

171 | * \param dst Destination |

172 | * \param src_begin The start point of copy (inclusive) |

173 | * \param src_end The end point of copy (exclusive) |

174 | * \return Self |

175 | */ |

176 | ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { |

177 | ObjectRef* from = MutableBegin() + src_begin; |

178 | ObjectRef* to = MutableBegin() + dst; |

179 | while (src_begin++ != src_end) { |

180 | *to++ = std::move(*from++); |

181 | } |

182 | return this; |

183 | } |

184 | |

185 | /*! |

186 | * \brief Move elements from left to right, requires src_begin < dst |

187 | * \param dst Destination |

188 | * \param src_begin The start point of move (inclusive) |

189 | * \param src_end The end point of move (exclusive) |

190 | * \return Self |

191 | */ |

192 | ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { |

193 | ObjectRef* from = MutableBegin() + src_end; |

194 | ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); |

195 | while (src_begin++ != src_end) { |

196 | *--to = std::move(*--from); |

197 | } |

198 | return this; |

199 | } |

200 | |

201 | /*! |

202 | * \brief Enlarges the size of the array |

203 | * \param delta Size enlarged, should be positive |

204 | * \param val Default value |

205 | * \return Self |

206 | */ |

207 | ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { |

208 | ObjectRef* itr = MutableEnd(); |

209 | while (delta-- > 0) { |

210 | new (itr++) ObjectRef(val); |

211 | ++size_; |

212 | } |

213 | return this; |

214 | } |

215 | |

216 | /*! |

217 | * \brief Shrinks the size of the array |

218 | * \param delta Size shrinked, should be positive |

219 | * \return Self |

220 | */ |

221 | ArrayNode* ShrinkBy(int64_t delta) { |

222 | ObjectRef* itr = MutableEnd(); |

223 | while (delta-- > 0) { |

224 | (--itr)->ObjectRef::~ObjectRef(); |

225 | --size_; |

226 | } |

227 | return this; |

228 | } |

229 | |

230 | /*! \brief Number of elements used */ |

231 | int64_t size_; |

232 | |

233 | /*! \brief Number of elements allocated */ |

234 | int64_t capacity_; |

235 | |

236 | /*! \brief Initial size of ArrayNode */ |

237 | static constexpr int64_t kInitSize = 4; |

238 | |

239 | /*! \brief Expansion factor of the Array */ |

240 | static constexpr int64_t kIncFactor = 2; |

241 | |

242 | // CRTP parent class |

243 | friend InplaceArrayBase<ArrayNode, ObjectRef>; |

244 | |

245 | // Reference class |

246 | template <typename, typename> |

247 | friend class Array; |

248 | |

249 | // To specialize make_object<ArrayNode> |

250 | friend ObjectPtr<ArrayNode> make_object<>(); |

251 | }; |

252 | |

/*!
 * \brief Compile-time check used to validate iterator ranges fed into Array<T>.
 *
 * is_valid_iterator<T, IterType>::value is true when std::is_base_of<T, E> holds, where E
 * is the type obtained by dereferencing an IterType with references and cv-qualifiers
 * stripped; that is, when the iterator yields something storable in an Array<T>.
 */
template <typename T, typename IterType>
struct is_valid_iterator
    : std::integral_constant<
          bool, std::is_base_of<T, typename std::remove_cv<typename std::remove_reference<
                                       decltype(*std::declval<IterType>())>::type>::type>::value> {
};

263 | |

264 | template <typename T, typename IterType> |

265 | struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {}; |

266 | |

267 | template <typename T, typename IterType> |

268 | inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value; |

269 | |

270 | /*! |

271 | * \brief Array, container representing a contiguous sequence of ObjectRefs. |

272 | * |

273 | * Array implements in-place copy-on-write semantics. |

274 | * |

275 | * As in typical copy-on-write, a method which would typically mutate the array |

276 | * instead opaquely copies the underlying container, and then acts on its copy. |

277 | * |

278 | * If the array has reference count equal to one, we directly update the |

279 | * container in place without copying. This is optimization is sound because |

280 | * when the reference count is equal to one this reference is guranteed to be |

281 | * the sole pointer to the container. |

282 | * |

283 | * |

284 | * operator[] only provides const access, use Set to mutate the content. |

285 | * \tparam T The content ObjectRef type. |

286 | */ |

287 | template <typename T, |

288 | typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> |

289 | class Array : public ObjectRef { |

290 | public: |

291 | using value_type = T; |

292 | // constructors |

293 | /*! |

294 | * \brief default constructor |

295 | */ |

296 | Array() { data_ = ArrayNode::Empty(); } |

297 | |

298 | /*! |

299 | * \brief move constructor |

300 | * \param other source |

301 | */ |

302 | Array(Array<T>&& other) : ObjectRef() { // NOLINT(*) |

303 | data_ = std::move(other.data_); |

304 | } |

305 | |

306 | /*! |

307 | * \brief copy constructor |

308 | * \param other source |

309 | */ |

310 | Array(const Array<T>& other) : ObjectRef() { // NOLINT(*) |

311 | data_ = other.data_; |

312 | } |

313 | |

314 | /*! |

315 | * \brief constructor from pointer |

316 | * \param n the container pointer |

317 | */ |

318 | explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {} |

319 | |

320 | /*! |

321 | * \brief Constructor from iterator |

322 | * \param first begin of iterator |

323 | * \param last end of iterator |

324 | * \tparam IterType The type of iterator |

325 | */ |

326 | template <typename IterType> |

327 | Array(IterType first, IterType last) { |

328 | static_assert(is_valid_iterator_v<T, IterType>, |

329 | "IterType cannot be inserted into a tvm::Array<T>"); |

330 | Assign(first, last); |

331 | } |

332 | |

333 | /*! |

334 | * \brief constructor from initializer list |

335 | * \param init The initializer list |

336 | */ |

337 | Array(std::initializer_list<T> init) { // NOLINT(*) |

338 | Assign(init.begin(), init.end()); |

339 | } |

340 | |

341 | /*! |

342 | * \brief constructor from vector |

343 | * \param init The vector |

344 | */ |

345 | Array(const std::vector<T>& init) { // NOLINT(*) |

346 | Assign(init.begin(), init.end()); |

347 | } |

348 | |

349 | /*! |

350 | * \brief Constructs a container with n elements. Each element is a copy of val |

351 | * \param n The size of the container |

352 | * \param val The init value |

353 | */ |

354 | explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } |

355 | |

356 | /*! |

357 | * \brief move assign operator |

358 | * \param other The source of assignment |

359 | * \return reference to self. |

360 | */ |

361 | Array<T>& operator=(Array<T>&& other) { |

362 | data_ = std::move(other.data_); |

363 | return *this; |

364 | } |

365 | |

366 | /*! |

367 | * \brief copy assign operator |

368 | * \param other The source of assignment |

369 | * \return reference to self. |

370 | */ |

371 | Array<T>& operator=(const Array<T>& other) { |

372 | data_ = other.data_; |

373 | return *this; |

374 | } |

375 | |

376 | public: |

377 | // iterators |

378 | struct ValueConverter { |

379 | using ResultType = T; |

380 | static T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); } |

381 | }; |

382 | |

383 | using iterator = IterAdapter<ValueConverter, const ObjectRef*>; |

384 | using reverse_iterator = ReverseIterAdapter<ValueConverter, const ObjectRef*>; |

385 | |

386 | /*! \return begin iterator */ |

387 | iterator begin() const { return iterator(GetArrayNode()->begin()); } |

388 | |

389 | /*! \return end iterator */ |

390 | iterator end() const { return iterator(GetArrayNode()->end()); } |

391 | |

392 | /*! \return rbegin iterator */ |

393 | reverse_iterator rbegin() const { |

394 | // ArrayNode::end() is never nullptr |

395 | return reverse_iterator(GetArrayNode()->end() - 1); |

396 | } |

397 | |

398 | /*! \return rend iterator */ |

399 | reverse_iterator rend() const { |

400 | // ArrayNode::begin() is never nullptr |

401 | return reverse_iterator(GetArrayNode()->begin() - 1); |

402 | } |

403 | |

404 | public: |

405 | // const methods in std::vector |

406 | /*! |

407 | * \brief Immutably read i-th element from array. |

408 | * \param i The index |

409 | * \return the i-th element. |

410 | */ |

411 | const T operator[](int64_t i) const { |

412 | ArrayNode* p = GetArrayNode(); |

413 | ICHECK(p != nullptr) << "ValueError: cannot index a null array"; |

414 | ICHECK(0 <= i && i < p->size_) |

415 | << "IndexError: indexing "<< i << " on an array of size "<< p->size_; |

416 | return DowncastNoCheck<T>(*(p->begin() + i)); |

417 | } |

418 | |

419 | /*! \return The size of the array */ |

420 | size_t size() const { |

421 | ArrayNode* p = GetArrayNode(); |

422 | return p == nullptr ? 0 : GetArrayNode()->size_; |

423 | } |

424 | |

425 | /*! \return The capacity of the array */ |

426 | size_t capacity() const { |

427 | ArrayNode* p = GetArrayNode(); |

428 | return p == nullptr ? 0 : GetArrayNode()->capacity_; |

429 | } |

430 | |

431 | /*! \return Whether array is empty */ |

432 | bool empty() const { return size() == 0; } |

433 | |

434 | /*! \return The first element of the array */ |

435 | const T front() const { |

436 | ArrayNode* p = GetArrayNode(); |

437 | ICHECK(p != nullptr) << "ValueError: cannot index a null array"; |

438 | ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; |

439 | return DowncastNoCheck<T>(*(p->begin())); |

440 | } |

441 | |

442 | /*! \return The last element of the array */ |

443 | const T back() const { |

444 | ArrayNode* p = GetArrayNode(); |

445 | ICHECK(p != nullptr) << "ValueError: cannot index a null array"; |

446 | ICHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; |

447 | return DowncastNoCheck<T>(*(p->end() - 1)); |

448 | } |

449 | |

450 | public: |

451 | // mutation in std::vector, implements copy-on-write |

452 | |

453 | /*! |

454 | * \brief push a new item to the back of the list |

455 | * \param item The item to be pushed. |

456 | */ |

457 | void push_back(const T& item) { |

458 | ArrayNode* p = CopyOnWrite(1); |

459 | p->EmplaceInit(p->size_++, item); |

460 | } |

461 | |

462 | /*! |

463 | * \brief Insert an element into the given position |

464 | * \param position An iterator pointing to the insertion point |

465 | * \param val The element to insert |

466 | */ |

467 | void insert(iterator position, const T& val) { |

468 | ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; |

469 | int64_t idx = std::distance(begin(), position); |

470 | int64_t size = GetArrayNode()->size_; |

471 | auto addr = CopyOnWrite(1) // |

472 | ->EnlargeBy(1) // |

473 | ->MoveElementsRight(idx + 1, idx, size) // |

474 | ->MutableBegin(); |

475 | new (addr + idx) ObjectRef(val); |

476 | } |

477 | |

478 | /*! |

479 | * \brief Insert a range of elements into the given position |

480 | * \param position An iterator pointing to the insertion point |

481 | * \param first The begin iterator of the range |

482 | * \param last The end iterator of the range |

483 | */ |

484 | template <typename IterType> |

485 | void insert(iterator position, IterType first, IterType last) { |

486 | static_assert(is_valid_iterator_v<T, IterType>, |

487 | "IterType cannot be inserted into a tvm::Array<T>"); |

488 | |

489 | if (first == last) { |

490 | return; |

491 | } |

492 | ICHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; |

493 | int64_t idx = std::distance(begin(), position); |

494 | int64_t size = GetArrayNode()->size_; |

495 | int64_t numel = std::distance(first, last); |

496 | CopyOnWrite(numel) |

497 | ->EnlargeBy(numel) |

498 | ->MoveElementsRight(idx + numel, idx, size) |

499 | ->InitRange(idx, first, last); |

500 | } |

501 | |

502 | /*! \brief Remove the last item of the list */ |

503 | void pop_back() { |

504 | ICHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; |

505 | int64_t size = GetArrayNode()->size_; |

506 | ICHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; |

507 | CopyOnWrite()->ShrinkBy(1); |

508 | } |

509 | |

510 | /*! |

511 | * \brief Erase an element on the given position |

512 | * \param position An iterator pointing to the element to be erased |

513 | */ |

514 | void erase(iterator position) { |

515 | ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; |

516 | int64_t st = std::distance(begin(), position); |

517 | int64_t size = GetArrayNode()->size_; |

518 | ICHECK(0 <= st && st < size) << "ValueError: cannot erase at index "<< st |

519 | << ", because Array size is "<< size; |

520 | CopyOnWrite() // |

521 | ->MoveElementsLeft(st, st + 1, size) // |

522 | ->ShrinkBy(1); |

523 | } |

524 | |

525 | /*! |

526 | * \brief Erase a given range of elements |

527 | * \param first The begin iterator of the range |

528 | * \param last The end iterator of the range |

529 | */ |

530 | void erase(iterator first, iterator last) { |

531 | if (first == last) { |

532 | return; |

533 | } |

534 | ICHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; |

535 | int64_t size = GetArrayNode()->size_; |

536 | int64_t st = std::distance(begin(), first); |

537 | int64_t ed = std::distance(begin(), last); |

538 | ICHECK_LT(st, ed) << "ValueError: cannot erase array in range ["<< st << ", "<< ed << ")"; |

539 | ICHECK(0 <= st && st <= size && 0 <= ed && ed <= size) |

540 | << "ValueError: cannot erase array in range ["<< st << ", "<< ed << ")" |

541 | << ", because array size is "<< size; |

542 | CopyOnWrite() // |

543 | ->MoveElementsLeft(st, ed, size) // |

544 | ->ShrinkBy(ed - st); |

545 | } |

546 | |

547 | /*! |

548 | * \brief Resize the array. |

549 | * \param n The new size. |

550 | */ |

551 | void resize(int64_t n) { |

552 | ICHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; |

553 | if (data_ == nullptr) { |

554 | SwitchContainer(n); |

555 | return; |

556 | } |

557 | int64_t size = GetArrayNode()->size_; |

558 | if (size < n) { |

559 | CopyOnWrite(n - size)->EnlargeBy(n - size); |

560 | } else if (size > n) { |

561 | CopyOnWrite()->ShrinkBy(size - n); |

562 | } |

563 | } |

564 | |

565 | /*! |

566 | * \brief Make sure the list has the capacity of at least n |

567 | * \param n lower bound of the capacity |

568 | */ |

569 | void reserve(int64_t n) { |

570 | if (data_ == nullptr || n > GetArrayNode()->capacity_) { |

571 | SwitchContainer(n); |

572 | } |

573 | } |

574 | |

575 | /*! \brief Release reference to all the elements */ |

576 | void clear() { |

577 | if (data_ != nullptr) { |

578 | ArrayNode* p = CopyOnWrite(); |

579 | p->clear(); |

580 | } |

581 | } |

582 | |

583 | public: |

584 | // Array's own methods |

585 | |

586 | /*! |

587 | * \brief set i-th element of the array. |

588 | * \param i The index |

589 | * \param value The value to be setted. |

590 | */ |

591 | void Set(int64_t i, T value) { |

592 | ArrayNode* p = this->CopyOnWrite(); |

593 | ICHECK(0 <= i && i < p->size_) |

594 | << "IndexError: indexing "<< i << " on an array of size "<< p->size_; |

595 | *(p->MutableBegin() + i) = std::move(value); |

596 | } |

597 | |

598 | /*! \return The underlying ArrayNode */ |

599 | ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); } |

600 | |

601 | /*! |

602 | * \brief Helper function to apply a map function onto the array. |

603 | * |

604 | * \param fmap The transformation function T -> U. |

605 | * |

606 | * \tparam F The type of the mutation function. |

607 | * |

608 | * \tparam U The type of the returned array, inferred from the |

609 | * return type of F. If overridden by the user, must be something |

610 | * that is convertible from the return type of F. |

611 | * |

612 | * \note This function performs copy on write optimization. If |

613 | * `fmap` returns an object of type `T`, and all elements of the |

614 | * array are mapped to themselves, then the returned array will be |

615 | * the same as the original, and reference counts of the elements in |

616 | * the array will not be incremented. |

617 | * |

618 | * \return The transformed array. |

619 | */ |

620 | template <typename F, typename U = std::invoke_result_t<F, T>> |

621 | Array<U> Map(F fmap) const { |

622 | return Array<U>(MapHelper(data_, fmap)); |

623 | } |

624 | |

625 | /*! |

626 | * \brief Helper function to apply fmutate to mutate an array. |

627 | * \param fmutate The transformation function T -> T. |

628 | * \tparam F the type of the mutation function. |

629 | * \note This function performs copy on write optimization. |

630 | */ |

631 | template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>> |

632 | void MutateByApply(F fmutate) { |

633 | data_ = MapHelper(std::move(data_), fmutate); |

634 | } |

635 | |

636 | /*! |

637 | * \brief reset the array to content from iterator. |

638 | * \param first begin of iterator |

639 | * \param last end of iterator |

640 | * \tparam IterType The type of iterator |

641 | */ |

642 | template <typename IterType> |

643 | void Assign(IterType first, IterType last) { |

644 | int64_t cap = std::distance(first, last); |

645 | ICHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; |

646 | ArrayNode* p = GetArrayNode(); |

647 | if (p != nullptr && data_.unique() && p->capacity_ >= cap) { |

648 | // do not have to make new space |

649 | p->clear(); |

650 | } else { |

651 | // create new space |

652 | data_ = ArrayNode::Empty(cap); |

653 | p = GetArrayNode(); |

654 | } |

655 | // To ensure exception safety, size is only incremented after the initialization succeeds |

656 | ObjectRef* itr = p->MutableBegin(); |

657 | for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { |

658 | new (itr) ObjectRef(*first); |

659 | } |

660 | } |

661 | |

662 | /*! |

663 | * \brief Copy on write semantics |

664 | * Do nothing if current handle is the unique copy of the array. |

665 | * Otherwise make a new copy of the array to ensure the current handle |

666 | * hold a unique copy. |

667 | * |

668 | * \return Handle to the internal node container(which ganrantees to be unique) |

669 | */ |

670 | ArrayNode* CopyOnWrite() { |

671 | if (data_ == nullptr) { |

672 | return SwitchContainer(ArrayNode::kInitSize); |

673 | } |

674 | if (!data_.unique()) { |

675 | return SwitchContainer(capacity()); |

676 | } |

677 | return static_cast<ArrayNode*>(data_.get()); |

678 | } |

679 | |

680 | /*! \brief specify container node */ |

681 | using ContainerType = ArrayNode; |

682 | |

683 | private: |

684 | /*! |

685 | * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. |

686 | * \param reserve_extra Number of extra slots needed |

687 | * \return ArrayNode pointer to the unique copy |

688 | */ |

689 | ArrayNode* CopyOnWrite(int64_t reserve_extra) { |

690 | ArrayNode* p = GetArrayNode(); |

691 | if (p == nullptr) { |

692 | // necessary to get around the constexpr address issue before c++17 |

693 | const int64_t kInitSize = ArrayNode::kInitSize; |

694 | return SwitchContainer(std::max(kInitSize, reserve_extra)); |

695 | } |

696 | if (p->capacity_ >= p->size_ + reserve_extra) { |

697 | return CopyOnWrite(); |

698 | } |

699 | int64_t cap = p->capacity_ * ArrayNode::kIncFactor; |

700 | cap = std::max(cap, p->size_ + reserve_extra); |

701 | return SwitchContainer(cap); |

702 | } |

703 | |

704 | /*! |

705 | * \brief Move or copy the ArrayNode to new address with the given capacity |

706 | * \param capacity The capacity requirement of the new address |

707 | */ |

708 | ArrayNode* SwitchContainer(int64_t capacity) { |

709 | if (data_ == nullptr) { |

710 | data_ = ArrayNode::Empty(capacity); |

711 | } else if (data_.unique()) { |

712 | data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); |

713 | } else { |

714 | data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); |

715 | } |

716 | return static_cast<ArrayNode*>(data_.get()); |

717 | } |

718 | |

719 | /*! \brief Helper method for mutate/map |

720 | * |

721 | * A helper function used internally by both `Array::Map` and |

722 | * `Array::MutateInPlace`. Given an array of data, apply the |

723 | * mapping function to each element, returning the collected array. |

724 | * Applies both mutate-in-place and copy-on-write optimizations, if |

725 | * possible. |

726 | * |

727 | * \param data A pointer to the ArrayNode containing input data. |

728 | * Passed by value to allow for mutate-in-place optimizations. |

729 | * |

730 | * \param fmap The mapping function |

731 | * |

732 | * \tparam F The type of the mutation function. |

733 | * |

734 | * \tparam U The output type of the mutation function. Inferred |

735 | * from the callable type given. Must inherit from ObjectRef. |

736 | * |

737 | * \return The mapped array. Depending on whether mutate-in-place |

738 | * or copy-on-write optimizations were applicable, may be the same |

739 | * underlying array as the `data` parameter. |

740 | */ |

741 | template <typename F, typename U = std::invoke_result_t<F, T>> |

742 | static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) { |

743 | if (data == nullptr) { |

744 | return nullptr; |

745 | } |

746 | |

747 | ICHECK(data->IsInstance<ArrayNode>()); |

748 | |

749 | constexpr bool is_same_output_type = std::is_same_v<T, U>; |

750 | |

751 | if constexpr (is_same_output_type) { |

752 | if (data.unique()) { |

753 | // Mutate-in-place path. Only allowed if the output type U is |

754 | // the same as type T, we have a mutable this*, and there are |

755 | // no other shared copies of the array. |

756 | auto arr = static_cast<ArrayNode*>(data.get()); |

757 | for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { |

758 | T mapped = fmap(DowncastNoCheck<T>(std::move(*it))); |

759 | *it = std::move(mapped); |

760 | } |

761 | return data; |

762 | } |

763 | } |

764 | |

765 | constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>; |

766 | |

767 | ObjectPtr<ArrayNode> output = nullptr; |

768 | auto arr = static_cast<ArrayNode*>(data.get()); |

769 | |

770 | auto it = arr->begin(); |

771 | if constexpr (compatible_types) { |

772 | // Copy-on-write path, if the output Array<U> might be |

773 | // represented by the same underlying array as the existing |

774 | // Array<T>. Typically, this is for functions that map `T` to |

775 | // `T`, but can also apply to functions that map `T` to |

776 | // `Optional<T>`, or that map `T` to a subclass or superclass of |

777 | // `T`. |

778 | bool all_identical = true; |

779 | for (; it != arr->end(); it++) { |

780 | U mapped = fmap(DowncastNoCheck<T>(*it)); |

781 | if (!mapped.same_as(*it)) { |

782 | // At least one mapped element is different than the |

783 | // original. Therefore, prepare the output array, |

784 | // consisting of any previous elements that had mapped to |

785 | // themselves (if any), and the element that didn't map to |

786 | // itself. |

787 | all_identical = false; |

788 | output = ArrayNode::CreateRepeated(arr->size(), U()); |

789 | output->InitRange(0, arr->begin(), it); |

790 | output->SetItem(it - arr->begin(), std::move(mapped)); |

791 | it++; |

792 | break; |

793 | } |

794 | } |

795 | if (all_identical) { |

796 | return data; |

797 | } |

798 | } else { |

799 | // Path for incompatible types. The constexpr check for |

800 | // compatible types isn't strictly necessary, as the first |

801 | // mapped.same_as(*it) would return false, but we might as well |

802 | // avoid it altogether. |

803 | output = ArrayNode::CreateRepeated(arr->size(), U()); |

804 | } |

805 | |

806 | // Normal path for incompatible types, or post-copy path for |

807 | // copy-on-write instances. |

808 | // |

809 | // If the types are incompatible, then at this point `output` is |

810 | // empty, and `it` points to the first element of the input. |

811 | // |

812 | // If the types were compatible, then at this point `output` |

813 | // contains zero or more elements that mapped to themselves |

814 | // followed by the first element that does not map to itself, and |

815 | // `it` points to the element just after the first element that |

816 | // does not map to itself. Because at least one element has been |

817 | // changed, we no longer have the opportunity to avoid a copy, so |

818 | // we don't need to check the result. |

819 | // |

820 | // In both cases, `it` points to the next element to be processed, |

821 | // so we can either start or resume the iteration from that point, |

822 | // with no further checks on the result. |

823 | for (; it != arr->end(); it++) { |

824 | U mapped = fmap(DowncastNoCheck<T>(*it)); |

825 | output->SetItem(it - arr->begin(), std::move(mapped)); |

826 | } |

827 | |

828 | return output; |

829 | } |

830 | }; |

831 | |

832 | /*! |

833 | * \brief Concat two Arrays. |

834 | * \param lhs first Array to be concatenated. |

835 | * \param rhs second Array to be concatenated. |

836 | * \return The concatenated Array. Original Arrays are kept unchanged. |

837 | */ |

838 | template <typename T, |

839 | typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> |

840 | inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) { |

841 | for (const auto& x : rhs) { |

842 | lhs.push_back(x); |

843 | } |

844 | return std::move(lhs); |

845 | } |

846 | |

847 | // Specialize make_object<ArrayNode> to make sure it is correct. |

848 | template <> |

849 | inline ObjectPtr<ArrayNode> make_object() { |

850 | return ArrayNode::Empty(); |

851 | } |

852 | |

853 | } // namespace runtime |

854 | |

855 | // expose the functions to the root namespace. |

856 | using runtime::Array; |

857 | using runtime::ArrayNode; |

858 | } // namespace tvm |

859 | |

860 | #endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ |

861 |