1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/container/map.h
22 * \brief Runtime Map container types.
23 */
24#ifndef TVM_RUNTIME_CONTAINER_MAP_H_
25#define TVM_RUNTIME_CONTAINER_MAP_H_
26
// Build switch: define USE_FALLBACK_STL_MAP to a non-zero value to implement MapNode on top of
// std::unordered_map instead of the custom SmallMapNode / DenseMapNode hash maps in this file.
// Defaults to 0, i.e. the custom implementation.
#ifndef USE_FALLBACK_STL_MAP
#define USE_FALLBACK_STL_MAP 0
#endif
30
31#include <algorithm>
32#include <unordered_map>
33#include <utility>
34
35#include "./base.h"
36#include "./optional.h"
37
38namespace tvm {
39namespace runtime {
40
// Debug-only consistency check used inside MapNode::iterator members.
// With TVM_LOG_DEBUG, an iterator copies the container's `state_marker` when it is constructed;
// this macro fails an ICHECK when the container's current `self->state_marker` no longer
// matches that copy. Updating `state_marker` on mutation is presumably done elsewhere (not
// visible here) — TODO confirm.
// The debug expansion already ends with ';', so call sites write `TVM_MAP_FAIL_IF_CHANGED()`
// with no trailing semicolon. Without TVM_LOG_DEBUG the macro expands to nothing.
#if TVM_LOG_DEBUG
#define TVM_MAP_FAIL_IF_CHANGED() \
  ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
#else
#define TVM_MAP_FAIL_IF_CHANGED()
#endif  // TVM_LOG_DEBUG
47
48#if (USE_FALLBACK_STL_MAP != 0)
49
50/*! \brief Shared content of all specializations of hash map */
51class MapNode : public Object {
52 public:
53 /*! \brief Type of the keys in the hash map */
54 using key_type = ObjectRef;
55 /*! \brief Type of the values in the hash map */
56 using mapped_type = ObjectRef;
57 /*! \brief Type of the actual underlying container */
58 using ContainerType = std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual>;
59 /*! \brief Iterator class */
60 using iterator = ContainerType::iterator;
61 /*! \brief Iterator class */
62 using const_iterator = ContainerType::const_iterator;
63 /*! \brief Type of value stored in the hash map */
64 using KVType = ContainerType::value_type;
65
66 static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
67 static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");
68
69 static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
70 static constexpr const char* _type_key = "Map";
71 TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
72
73 /*!
74 * \brief Number of elements in the SmallMapNode
75 * \return The result
76 */
77 size_t size() const { return data_.size(); }
78 /*!
79 * \brief Count the number of times a key exists in the hash map
80 * \param key The indexing key
81 * \return The result, 0 or 1
82 */
83 size_t count(const key_type& key) const { return data_.count(key); }
84 /*!
85 * \brief Index value associated with a key, throw exception if the key does not exist
86 * \param key The indexing key
87 * \return The const reference to the value
88 */
89 const mapped_type& at(const key_type& key) const { return data_.at(key); }
90 /*!
91 * \brief Index value associated with a key, throw exception if the key does not exist
92 * \param key The indexing key
93 * \return The mutable reference to the value
94 */
95 mapped_type& at(const key_type& key) { return data_.at(key); }
96 /*! \return begin iterator */
97 iterator begin() { return data_.begin(); }
98 /*! \return const begin iterator */
99 const_iterator begin() const { return data_.begin(); }
100 /*! \return end iterator */
101 iterator end() { return data_.end(); }
102 /*! \return end iterator */
103 const_iterator end() const { return data_.end(); }
104 /*!
105 * \brief Index value associated with a key
106 * \param key The indexing key
107 * \return The iterator of the entry associated with the key, end iterator if not exists
108 */
109 const_iterator find(const key_type& key) const { return data_.find(key); }
110 /*!
111 * \brief Index value associated with a key
112 * \param key The indexing key
113 * \return The iterator of the entry associated with the key, end iterator if not exists
114 */
115 iterator find(const key_type& key) { return data_.find(key); }
116 /*!
117 * \brief Erase the entry associated with the iterator
118 * \param position The iterator
119 */
120 void erase(const iterator& position) { data_.erase(position); }
121 /*!
122 * \brief Erase the entry associated with the key, do nothing if not exists
123 * \param key The indexing key
124 */
125 void erase(const key_type& key) { data_.erase(key); }
126 /*!
127 * \brief Create an empty container
128 * \return The object created
129 */
130 static ObjectPtr<MapNode> Empty() { return make_object<MapNode>(); }
131
132 protected:
133 /*!
134 * \brief Create the map using contents from the given iterators.
135 * \param first Begin of iterator
136 * \param last End of iterator
137 * \tparam IterType The type of iterator
138 * \return ObjectPtr to the map created
139 */
140 template <typename IterType>
141 static ObjectPtr<Object> CreateFromRange(IterType first, IterType last) {
142 ObjectPtr<MapNode> p = make_object<MapNode>();
143 p->data_ = ContainerType(first, last);
144 return p;
145 }
146 /*!
147 * \brief InsertMaybeReHash an entry into the given hash map
148 * \param kv The entry to be inserted
149 * \param map The pointer to the map, can be changed if re-hashing happens
150 */
151 static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
152 MapNode* map_node = static_cast<MapNode*>(map->get());
153 map_node->data_[kv.first] = kv.second;
154 }
155 /*!
156 * \brief Create an empty container with elements copying from another MapNode
157 * \param from The source container
158 * \return The object created
159 */
160 static ObjectPtr<MapNode> CopyFrom(MapNode* from) {
161 ObjectPtr<MapNode> p = make_object<MapNode>();
162 p->data_ = ContainerType(from->data_.begin(), from->data_.end());
163 return p;
164 }
165 /*! \brief The real container storing data */
166 ContainerType data_;
167 template <typename, typename, typename, typename>
168 friend class Map;
169};
170
171#else
172
/*!
 * \brief Shared content of all specializations of hash map.
 *
 * Base class of the custom implementations SmallMapNode and DenseMapNode. Members declared here
 * without a body are defined elsewhere (outside this section); presumably they dispatch to the
 * concrete specialization — TODO confirm.
 */
class MapNode : public Object {
 public:
  /*! \brief Type of the keys in the hash map */
  using key_type = ObjectRef;
  /*! \brief Type of the values in the hash map */
  using mapped_type = ObjectRef;
  /*! \brief Type of value stored in the hash map */
  using KVType = std::pair<ObjectRef, ObjectRef>;
  /*! \brief Iterator class (defined below in this class) */
  class iterator;

  static_assert(std::is_standard_layout<KVType>::value, "KVType is not standard layout");
  static_assert(sizeof(KVType) == 16 || sizeof(KVType) == 8, "sizeof(KVType) incorrect");

  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
  static constexpr const char* _type_key = "Map";
  TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);

  /*!
   * \brief Number of elements in the map
   * \return The result
   */
  size_t size() const { return size_; }
  /*!
   * \brief Count the number of times a key exists in the hash map
   * \param key The indexing key
   * \return The result, 0 or 1
   */
  size_t count(const key_type& key) const;
  /*!
   * \brief Index value associated with a key, throw exception if the key does not exist
   * \param key The indexing key
   * \return The const reference to the value
   */
  const mapped_type& at(const key_type& key) const;
  /*!
   * \brief Index value associated with a key, throw exception if the key does not exist
   * \param key The indexing key
   * \return The mutable reference to the value
   */
  mapped_type& at(const key_type& key);
  /*! \return begin iterator */
  iterator begin() const;
  /*! \return end iterator */
  iterator end() const;
  /*!
   * \brief Index value associated with a key
   * \param key The indexing key
   * \return The iterator of the entry associated with the key, end iterator if not exists
   */
  iterator find(const key_type& key) const;
  /*!
   * \brief Erase the entry associated with the iterator
   * \param position The iterator
   */
  void erase(const iterator& position);
  /*!
   * \brief Erase the entry associated with the key, do nothing if not exists
   *        (a missing key yields end() from find(), which erase(iterator) ignores)
   * \param key The indexing key
   */
  void erase(const key_type& key) { erase(find(key)); }

  /*!
   * \brief Iterator over a MapNode: a position `index` within the container `self`.
   *
   * The meaning of `index` depends on the specialization. For SmallMapNode it is the position in
   * the in-place array. For DenseMapNode it is the slot number. Increment, decrement and
   * dereference are defined by the specializations (outside this section).
   */
  class iterator {
   public:
    using iterator_category = std::forward_iterator_tag;
    using difference_type = int64_t;
    using value_type = KVType;
    using pointer = KVType*;
    using reference = KVType&;
    /*! \brief Default constructor: a null iterator (index 0, no container) */
#if TVM_LOG_DEBUG
    iterator() : state_marker(0), index(0), self(nullptr) {}
#else
    iterator() : index(0), self(nullptr) {}
#endif  // TVM_LOG_DEBUG
    /*! \brief Compare iterators: equal when they refer to the same position of the same map */
    bool operator==(const iterator& other) const {
      TVM_MAP_FAIL_IF_CHANGED()
      return index == other.index && self == other.self;
    }
    /*! \brief Compare iterators */
    bool operator!=(const iterator& other) const { return !(*this == other); }
    /*! \brief De-reference iterators */
    pointer operator->() const;
    /*! \brief De-reference iterators */
    reference operator*() const {
      TVM_MAP_FAIL_IF_CHANGED()
      return *((*this).operator->());
    }
    /*! \brief Prefix self increment, e.g. ++iter */
    iterator& operator++();
    /*! \brief Prefix self decrement, e.g. --iter */
    iterator& operator--();
    /*! \brief Suffix self increment */
    iterator operator++(int) {
      TVM_MAP_FAIL_IF_CHANGED()
      iterator copy = *this;
      ++(*this);
      return copy;
    }
    /*! \brief Suffix self decrement */
    iterator operator--(int) {
      TVM_MAP_FAIL_IF_CHANGED()
      iterator copy = *this;
      --(*this);
      return copy;
    }

   protected:
#if TVM_LOG_DEBUG
    /*! \brief Snapshot of the container's state_marker taken at construction (debug only) */
    uint64_t state_marker;
    /*! \brief Construct by value */
    iterator(uint64_t index, const MapNode* self)
        : state_marker(self->state_marker), index(index), self(self) {}

#else
    /*! \brief Construct by value */
    iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
#endif  // TVM_LOG_DEBUG
    /*! \brief The position on the array */
    uint64_t index;
    /*! \brief The container it points to */
    const MapNode* self;

    friend class DenseMapNode;
    friend class SmallMapNode;
  };
  /*!
   * \brief Create an empty container
   * \return The object created
   */
  static inline ObjectPtr<MapNode> Empty();

 protected:
#if TVM_LOG_DEBUG
  /*! \brief Marker compared by TVM_MAP_FAIL_IF_CHANGED to detect modification (debug only) */
  uint64_t state_marker;
#endif  // TVM_LOG_DEBUG
  /*!
   * \brief Create the map using contents from the given iterators.
   * \param first Begin of iterator
   * \param last End of iterator
   * \tparam IterType The type of iterator
   * \return ObjectPtr to the map created
   */
  template <typename IterType>
  static inline ObjectPtr<Object> CreateFromRange(IterType first, IterType last);
  /*!
   * \brief InsertMaybeReHash an entry into the given hash map
   * \param kv The entry to be inserted
   * \param map The pointer to the map, can be changed if re-hashing happens
   */
  static inline void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map);
  /*!
   * \brief Create a new container holding copies of the elements of another map
   * \param from The source container
   * \return The object created
   */
  static inline ObjectPtr<MapNode> CopyFrom(MapNode* from);
  /*!
   * \brief Slot bookkeeping; its meaning depends on the specialization.
   * SmallMapNode stores the capacity of its in-place array here.
   * DenseMapNode stores the number of slots minus 1, i.e. the largest valid slot index.
   */
  uint64_t slots_;
  /*! \brief number of entries in the container */
  uint64_t size_;
  // Reference class
  template <typename, typename, typename, typename>
  friend class Map;
};
339
340/*! \brief A specialization of small-sized hash map */
341class SmallMapNode : public MapNode,
342 public runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType> {
343 private:
344 static constexpr uint64_t kInitSize = 2;
345 static constexpr uint64_t kMaxSize = 4;
346
347 public:
348 using MapNode::iterator;
349 using MapNode::KVType;
350
351 /*! \brief Defaults to the destructor of InplaceArrayBase */
352 ~SmallMapNode() = default;
353 /*!
354 * \brief Count the number of times a key exists in the SmallMapNode
355 * \param key The indexing key
356 * \return The result, 0 or 1
357 */
358 size_t count(const key_type& key) const { return find(key).index < size_; }
359 /*!
360 * \brief Index value associated with a key, throw exception if the key does not exist
361 * \param key The indexing key
362 * \return The const reference to the value
363 */
364 const mapped_type& at(const key_type& key) const {
365 iterator itr = find(key);
366 ICHECK(itr.index < size_) << "IndexError: key is not in Map";
367 return itr->second;
368 }
369 /*!
370 * \brief Index value associated with a key, throw exception if the key does not exist
371 * \param key The indexing key
372 * \return The mutable reference to the value
373 */
374 mapped_type& at(const key_type& key) {
375 iterator itr = find(key);
376 ICHECK(itr.index < size_) << "IndexError: key is not in Map";
377 return itr->second;
378 }
379 /*! \return begin iterator */
380 iterator begin() const { return iterator(0, this); }
381 /*! \return end iterator */
382 iterator end() const { return iterator(size_, this); }
383 /*!
384 * \brief Index value associated with a key
385 * \param key The indexing key
386 * \return The iterator of the entry associated with the key, end iterator if not exists
387 */
388 iterator find(const key_type& key) const {
389 KVType* ptr = static_cast<KVType*>(AddressOf(0));
390 for (uint64_t i = 0; i < size_; ++i, ++ptr) {
391 if (ObjectEqual()(ptr->first, key)) {
392 return iterator(i, this);
393 }
394 }
395 return iterator(size_, this);
396 }
397 /*!
398 * \brief Erase the entry associated with the iterator
399 * \param position The iterator
400 */
401 void erase(const iterator& position) { Erase(position.index); }
402
403 private:
404 /*!
405 * \brief Remove a position in SmallMapNode
406 * \param index The position to be removed
407 */
408 void Erase(const uint64_t index) {
409 if (index >= size_) {
410 return;
411 }
412 KVType* begin = static_cast<KVType*>(AddressOf(0));
413 KVType* last = begin + (size_ - 1);
414 if (index + 1 == size_) {
415 last->first.ObjectRef::~ObjectRef();
416 last->second.ObjectRef::~ObjectRef();
417 } else {
418 *(begin + index) = std::move(*last);
419 }
420 size_ -= 1;
421 }
422 /*!
423 * \brief Create an empty container
424 * \param n Number of empty slots
425 * \return The object created
426 */
427 static ObjectPtr<SmallMapNode> Empty(uint64_t n = kInitSize) {
428 using ::tvm::runtime::make_inplace_array_object;
429 ObjectPtr<SmallMapNode> p = make_inplace_array_object<SmallMapNode, KVType>(n);
430 p->size_ = 0;
431 p->slots_ = n;
432 return p;
433 }
434 /*!
435 * \brief Create an empty container initialized with a given range
436 * \param n Number of empty slots
437 * \param first begin of iterator
438 * \param last end of iterator
439 * \tparam IterType The type of iterator
440 * \return The object created
441 */
442 template <typename IterType>
443 static ObjectPtr<SmallMapNode> CreateFromRange(uint64_t n, IterType first, IterType last) {
444 ObjectPtr<SmallMapNode> p = Empty(n);
445 KVType* ptr = static_cast<KVType*>(p->AddressOf(0));
446 for (; first != last; ++first, ++p->size_) {
447 new (ptr++) KVType(*first);
448 }
449 return p;
450 }
451 /*!
452 * \brief Create an empty container with elements copying from another SmallMapNode
453 * \param from The source container
454 * \return The object created
455 */
456 static ObjectPtr<SmallMapNode> CopyFrom(SmallMapNode* from) {
457 KVType* first = static_cast<KVType*>(from->AddressOf(0));
458 KVType* last = first + from->size_;
459 return CreateFromRange(from->size_, first, last);
460 }
461 /*!
462 * \brief InsertMaybeReHash an entry into the given hash map
463 * \param kv The entry to be inserted
464 * \param map The pointer to the map, can be changed if re-hashing happens
465 */
466 static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
467 SmallMapNode* map_node = static_cast<SmallMapNode*>(map->get());
468 iterator itr = map_node->find(kv.first);
469 if (itr.index < map_node->size_) {
470 itr->second = kv.second;
471 return;
472 }
473 if (map_node->size_ < map_node->slots_) {
474 KVType* ptr = static_cast<KVType*>(map_node->AddressOf(map_node->size_));
475 new (ptr) KVType(kv);
476 ++map_node->size_;
477 return;
478 }
479 uint64_t next_size = std::max(map_node->slots_ * 2, uint64_t(kInitSize));
480 next_size = std::min(next_size, uint64_t(kMaxSize));
481 ICHECK_GT(next_size, map_node->slots_);
482 ObjectPtr<Object> new_map = CreateFromRange(next_size, map_node->begin(), map_node->end());
483 InsertMaybeReHash(kv, &new_map);
484 *map = std::move(new_map);
485 }
486 /*!
487 * \brief Increment the pointer
488 * \param index The pointer to be incremented
489 * \return The increased pointer
490 */
491 uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; }
492 /*!
493 * \brief Decrement the pointer
494 * \param index The pointer to be decremented
495 * \return The decreased pointer
496 */
497 uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; }
498 /*!
499 * \brief De-reference the pointer
500 * \param index The pointer to be dereferenced
501 * \return The result
502 */
503 KVType* DeRefItr(uint64_t index) const { return static_cast<KVType*>(AddressOf(index)); }
504 /*! \brief A size function used by InplaceArrayBase */
505 uint64_t GetSize() const { return size_; }
506
507 protected:
508 friend class MapNode;
509 friend class DenseMapNode;
510 friend class runtime::InplaceArrayBase<SmallMapNode, MapNode::KVType>;
511};
512
/*! \brief A specialization of hash map that implements the idea of an array-based hash map.
 * Another reference implementation can be found in [1].
515 *
516 * A. Overview
517 *
 * DenseMapNode makes several improvements over a traditional separate-chaining hash map
 * in terms of cache locality, memory footprint and data organization.
520 *
521 * A1. Implicit linked list. For better cache locality, instead of using linked list
522 * explicitly for each bucket, we store list data into a single array that spans contiguously
523 * in memory, and then carefully design access patterns to make sure most of them fall into
524 * a single cache line.
525 *
 * A2. 1-byte metadata. Each slot in the array carries only 1 byte of overhead for indexing and
 * traversal. This byte is interpreted in 3 parts.
528 * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected,
529 * which means the slot is empty but not allowed to be written.
530 * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is
531 * head of a linked list.
 * 3) The remaining 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit
533 * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when
534 * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are
535 * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to
536 * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element,
537 * then x must be one of the 126 pre-defined values.
538 *
539 * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block.
 * The 16 metadata bytes of those 16 elements are stored together, followed by the real data, i.e.
541 * 16 key-value pairs.
542 *
543 * B. Implementation details
544 *
 * B1. Power-of-2 table size and Fibonacci Hashing. We use a power of two as the table size to avoid
 * the modulo operation and keep index arithmetic cheap. To make the hash-to-slot mapping
 * distribute more evenly, we use the Fibonacci Hashing [2] trick.
548 *
549 * B2. Traverse a linked list in the array.
 * 1) List head. Assume Fibonacci Hashing maps a given key to slot i. If the metadata at slot i
 * indicates that it is a list head, then we have found the head; otherwise the list is empty. No
 * probing is done in this procedure.
 * 2) Next element. To find the next element of a non-empty slot i, we look at the last 7 bits of
 * the metadata at slot i. If they are all zeros, slot i is the end of the list; otherwise, the
 * next element is at slot (i + candidates[the-last-7-bits]).
555 *
 * B3. Inserting an element. Following B2, we first traverse the linked list to see if this
 * element is already in it. If not, we append it at the end of the list by probing for the next
 * empty slot among the 126 candidate positions. If the linked list does not exist yet but the
 * slot for its head is occupied by an element of another linked list, that intruding element
 * (and the rest of its list) must be moved elsewhere first.
561 *
 * B4. Quadratic probing with triangle numbers. In open-addressing hashing, it is provable that
 * probing with triangle numbers traverses a power-of-2-sized table [3]. In our algorithm, we
 * follow the suggestion in [1] and also use triangle numbers for the "next pointer" jump
 * distances, as well as when sparing a slot for a list head.
566 *
567 * [1] https://github.com/skarupke/flat_hash_map
568 * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/
569 * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
570 */
571class DenseMapNode : public MapNode {
 private:
  /*! \brief The number of elements (slots) in a memory block */
  static constexpr int kBlockCap = 16;
  /*! \brief Maximum load factor of the hash map; IsFull() compares against it */
  static constexpr double kMaxLoadFactor = 0.99;
  /*! \brief Binary representation of the metadata of an empty slot */
  static constexpr uint8_t kEmptySlot = uint8_t(0b11111111);
  /*!
   * \brief Binary representation of the metadata of a protected slot.
   * TrySpareListHead marks a slot protected while it relocates the slot's list. Reset and the
   * rehash loop in InsertMaybeReHash skip protected slots, and CopyFrom ICHECKs that none exist.
   */
  static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110);
  /*! \brief Number of probing choices (candidate jump distances) available */
  static constexpr int kNumJumpDists = 126;
  /*!
   * \brief Handle to one slot of the table, used to walk the implicit linked lists
   *        (forward declaration; the definition is not in this section)
   */
  struct ListNode;
  /*!
   * \brief POD type of a block of memory: kBlockCap metadata bytes followed by
   *        kBlockCap (uninitialized until used) KVType entries
   */
  struct Block {
    uint8_t bytes[kBlockCap + kBlockCap * sizeof(KVType)];
  };
  static_assert(sizeof(Block) == kBlockCap * (sizeof(KVType) + 1), "sizeof(Block) incorrect");
  static_assert(std::is_standard_layout<Block>::value, "Block is not standard layout");

 public:
  using MapNode::iterator;
594
595 /*!
596 * \brief Destroy the DenseMapNode
597 */
598 ~DenseMapNode() { this->Reset(); }
599 /*! \return The number of elements of the key */
600 size_t count(const key_type& key) const { return !Search(key).IsNone(); }
601 /*!
602 * \brief Index value associated with a key, throw exception if the key does not exist
603 * \param key The indexing key
604 * \return The const reference to the value
605 */
606 const mapped_type& at(const key_type& key) const { return At(key); }
607 /*!
608 * \brief Index value associated with a key, throw exception if the key does not exist
609 * \param key The indexing key
610 * \return The mutable reference to the value
611 */
612 mapped_type& at(const key_type& key) { return At(key); }
613 /*!
614 * \brief Index value associated with a key
615 * \param key The indexing key
616 * \return The iterator of the entry associated with the key, end iterator if not exists
617 */
618 iterator find(const key_type& key) const {
619 ListNode node = Search(key);
620 return node.IsNone() ? end() : iterator(node.index, this);
621 }
622 /*!
623 * \brief Erase the entry associated with the iterator
624 * \param position The iterator
625 */
626 void erase(const iterator& position) {
627 uint64_t index = position.index;
628 if (position.self != nullptr && index <= this->slots_) {
629 Erase(ListNode(index, this));
630 }
631 }
632 /*! \return begin iterator */
633 iterator begin() const {
634 if (slots_ == 0) {
635 return iterator(0, this);
636 }
637 for (uint64_t index = 0; index <= slots_; ++index) {
638 if (!ListNode(index, this).IsEmpty()) {
639 return iterator(index, this);
640 }
641 }
642 return iterator(slots_ + 1, this);
643 }
644 /*! \return end iterator */
645 iterator end() const { return slots_ == 0 ? iterator(0, this) : iterator(slots_ + 1, this); }
646
647 private:
648 /*!
649 * \brief Search for the given key
650 * \param key The key
651 * \return ListNode that associated with the key
652 */
653 ListNode Search(const key_type& key) const {
654 if (this->size_ == 0) {
655 return ListNode();
656 }
657 for (ListNode iter = GetListHead(ObjectHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) {
658 if (ObjectEqual()(key, iter.Key())) {
659 return iter;
660 }
661 }
662 return ListNode();
663 }
664 /*!
665 * \brief Search for the given key, throw exception if not exists
666 * \param key The key
667 * \return ListNode that associated with the key
668 */
669 mapped_type& At(const key_type& key) const {
670 ListNode iter = Search(key);
671 ICHECK(!iter.IsNone()) << "IndexError: key is not in Map";
672 return iter.Val();
673 }
674 /*!
675 * \brief Try to insert a key, or do nothing if already exists
676 * \param key The indexing key
677 * \param result The linked-list entry found or just constructed
678 * \return A boolean, indicating if actual insertion happens
679 */
680 bool TryInsert(const key_type& key, ListNode* result) {
681 if (slots_ == 0) {
682 return false;
683 }
684 // required that `iter` to be the head of a linked list through which we can iterator
685 ListNode iter = IndexFromHash(ObjectHash()(key));
686 // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list
687 // Case 1: empty
688 if (iter.IsEmpty()) {
689 iter.NewHead(KVType(key, ObjectRef(nullptr)));
690 this->size_ += 1;
691 *result = iter;
692 return true;
693 }
694 // Case 2: body of an irrelevant list
695 if (!iter.IsHead()) {
696 // we move the elements around and construct the single-element linked list
697 return IsFull() ? false : TrySpareListHead(iter, key, result);
698 }
699 // Case 3: head of the relevant list
700 // we iterate through the linked list until the end
701 // make sure `iter` is the previous element of `next`
702 ListNode next = iter;
703 do {
704 // find equal item, do not insert
705 if (ObjectEqual()(key, next.Key())) {
706 *result = next;
707 return true;
708 }
709 // make sure `iter` is the previous element of `next`
710 iter = next;
711 } while (next.MoveToNext(this));
712 // `iter` is the tail of the linked list
713 // always check capacity before insertion
714 if (IsFull()) {
715 return false;
716 }
717 // find the next empty slot
718 uint8_t jump;
719 if (!iter.GetNextEmpty(this, &jump, result)) {
720 return false;
721 }
722 result->NewTail(KVType(key, ObjectRef(nullptr)));
723 // link `iter` to `empty`, and move forward
724 iter.SetJump(jump);
725 this->size_ += 1;
726 return true;
727 }
728 /*!
729 * \brief Spare an entry to be the head of a linked list.
730 * As described in B3, during insertion, it is possible that the entire linked list does not
731 * exist, but the slot of its head has been occupied by other linked lists. In this case, we need
732 * to spare the slot by moving away the elements to another valid empty one to make insertion
733 * possible.
734 * \param target The given entry to be spared
735 * \param key The indexing key
736 * \param result The linked-list entry constructed as the head
737 * \return A boolean, if actual insertion happens
738 */
739 bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) {
740 // `target` is not the head of the linked list
741 // move the original item of `target` (if any)
742 // and construct new item on the position `target`
743 // To make `target` empty, we
744 // 1) find `w` the previous element of `target` in the linked list
745 // 2) copy the linked list starting from `r = target`
746 // 3) paste them after `w`
747 // read from the linked list after `r`
748 ListNode r = target;
749 // write to the tail of `w`
750 ListNode w = target.FindPrev(this);
751 // after `target` is moved, we disallow writing to the slot
752 bool is_first = true;
753 uint8_t r_meta, jump;
754 ListNode empty;
755 do {
756 // `jump` describes how `w` is jumped to `empty`
757 // rehash if there is no empty space after `w`
758 if (!w.GetNextEmpty(this, &jump, &empty)) {
759 return false;
760 }
761 // move `r` to `empty`
762 empty.NewTail(std::move(r.Data()));
763 // clear the metadata of `r`
764 r_meta = r.Meta();
765 if (is_first) {
766 is_first = false;
767 r.SetProtected();
768 } else {
769 r.SetEmpty();
770 }
771 // link `w` to `empty`, and move forward
772 w.SetJump(jump);
773 w = empty;
774 // move `r` forward as well
775 } while (r.MoveToNext(this, r_meta));
776 // finally we have done moving the linked list
777 // fill data_ into `target`
778 target.NewHead(KVType(key, ObjectRef(nullptr)));
779 this->size_ += 1;
780 *result = target;
781 return true;
782 }
783 /*!
784 * \brief Remove a ListNode
785 * \param iter The node to be removed
786 */
787 void Erase(const ListNode& iter) {
788 this->size_ -= 1;
789 if (!iter.HasNext()) {
790 // `iter` is the last
791 if (!iter.IsHead()) {
792 // cut the link if there is any
793 iter.FindPrev(this).SetJump(0);
794 }
795 iter.Data().KVType::~KVType();
796 iter.SetEmpty();
797 } else {
798 ListNode last = iter, prev = iter;
799 for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) {
800 }
801 iter.Data() = std::move(last.Data());
802 last.SetEmpty();
803 prev.SetJump(0);
804 }
805 }
806 /*! \brief Clear the container to empty, release all entries and memory acquired */
807 void Reset() {
808 uint64_t n_blocks = CalcNumBlocks(this->slots_);
809 for (uint64_t bi = 0; bi < n_blocks; ++bi) {
810 uint8_t* meta_ptr = data_[bi].bytes;
811 KVType* data_ptr = reinterpret_cast<KVType*>(data_[bi].bytes + kBlockCap);
812 for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
813 uint8_t& meta = *meta_ptr;
814 if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
815 meta = uint8_t(kEmptySlot);
816 data_ptr->KVType::~KVType();
817 }
818 }
819 }
820 ReleaseMemory();
821 }
822 /*! \brief Release the memory acquired by the container without deleting its entries stored inside
823 */
824 void ReleaseMemory() {
825 delete[] data_;
826 data_ = nullptr;
827 slots_ = 0;
828 size_ = 0;
829 fib_shift_ = 63;
830 }
831 /*!
832 * \brief Create an empty container
833 * \param fib_shift The fib shift provided
834 * \param n_slots Number of slots required, should be power-of-two
835 * \return The object created
836 */
837 static ObjectPtr<DenseMapNode> Empty(uint32_t fib_shift, uint64_t n_slots) {
838 ICHECK_GT(n_slots, uint64_t(SmallMapNode::kMaxSize));
839 ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
840 uint64_t n_blocks = CalcNumBlocks(n_slots - 1);
841 Block* block = p->data_ = new Block[n_blocks];
842 p->slots_ = n_slots - 1;
843 p->size_ = 0;
844 p->fib_shift_ = fib_shift;
845 for (uint64_t i = 0; i < n_blocks; ++i, ++block) {
846 std::fill(block->bytes, block->bytes + kBlockCap, uint8_t(kEmptySlot));
847 }
848 return p;
849 }
850 /*!
851 * \brief Create an empty container with elements copying from another DenseMapNode
852 * \param from The source container
853 * \return The object created
854 */
855 static ObjectPtr<DenseMapNode> CopyFrom(DenseMapNode* from) {
856 ObjectPtr<DenseMapNode> p = make_object<DenseMapNode>();
857 uint64_t n_blocks = CalcNumBlocks(from->slots_);
858 p->data_ = new Block[n_blocks];
859 p->slots_ = from->slots_;
860 p->size_ = from->size_;
861 p->fib_shift_ = from->fib_shift_;
862 for (uint64_t bi = 0; bi < n_blocks; ++bi) {
863 uint8_t* meta_ptr_from = from->data_[bi].bytes;
864 KVType* data_ptr_from = reinterpret_cast<KVType*>(from->data_[bi].bytes + kBlockCap);
865 uint8_t* meta_ptr_to = p->data_[bi].bytes;
866 KVType* data_ptr_to = reinterpret_cast<KVType*>(p->data_[bi].bytes + kBlockCap);
867 for (int j = 0; j < kBlockCap;
868 ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) {
869 uint8_t& meta = *meta_ptr_to = *meta_ptr_from;
870 ICHECK(meta != kProtectedSlot);
871 if (meta != uint8_t(kEmptySlot)) {
872 new (data_ptr_to) KVType(*data_ptr_from);
873 }
874 }
875 }
876 return p;
877 }
878 /*!
879 * \brief InsertMaybeReHash an entry into the given hash map
880 * \param kv The entry to be inserted
881 * \param map The pointer to the map, can be changed if re-hashing happens
882 */
883 static void InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
884 DenseMapNode* map_node = static_cast<DenseMapNode*>(map->get());
885 ListNode iter;
886 // Try to insert. If succeed, we simply return
887 if (map_node->TryInsert(kv.first, &iter)) {
888 iter.Val() = kv.second;
889 return;
890 }
891 ICHECK_GT(map_node->slots_, uint64_t(SmallMapNode::kMaxSize));
892 // Otherwise, start rehash
893 ObjectPtr<Object> p = Empty(map_node->fib_shift_ - 1, map_node->slots_ * 2 + 2);
894 // Insert the given `kv` into the new hash map
895 InsertMaybeReHash(kv, &p);
896 uint64_t n_blocks = CalcNumBlocks(map_node->slots_);
897 // Then Insert data from the original block.
898 for (uint64_t bi = 0; bi < n_blocks; ++bi) {
899 uint8_t* meta_ptr = map_node->data_[bi].bytes;
900 KVType* data_ptr = reinterpret_cast<KVType*>(map_node->data_[bi].bytes + kBlockCap);
901 for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) {
902 uint8_t& meta = *meta_ptr;
903 if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
904 meta = uint8_t(kEmptySlot);
905 KVType kv = std::move(*data_ptr);
906 InsertMaybeReHash(kv, &p);
907 }
908 }
909 }
910 map_node->ReleaseMemory();
911 *map = p;
912 }
913 /*!
914 * \brief Check whether the hash table is full
915 * \return A boolean indicating whether hash table is full
916 */
917 bool IsFull() const { return size_ + 1 > (slots_ + 1) * kMaxLoadFactor; }
918 /*!
919 * \brief Increment the pointer
920 * \param index The pointer to be incremented
921 * \return The increased pointer
922 */
923 uint64_t IncItr(uint64_t index) const {
924 for (++index; index <= slots_; ++index) {
925 if (!ListNode(index, this).IsEmpty()) {
926 return index;
927 }
928 }
929 return slots_ + 1;
930 }
931 /*!
932 * \brief Decrement the pointer
933 * \param index The pointer to be decremented
934 * \return The decreased pointer
935 */
936 uint64_t DecItr(uint64_t index) const {
937 while (index != 0) {
938 index -= 1;
939 if (!ListNode(index, this).IsEmpty()) {
940 return index;
941 }
942 }
943 return slots_ + 1;
944 }
945 /*!
946 * \brief De-reference the pointer
947 * \param index The pointer to be dereferenced
948 * \return The result
949 */
950 KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); }
951 /*! \brief Construct from hash code */
952 ListNode IndexFromHash(uint64_t hash_value) const {
953 return ListNode(FibHash(hash_value, fib_shift_), this);
954 }
955 /*! \brief Construct from hash code if the position is head of list */
956 ListNode GetListHead(uint64_t hash_value) const {
957 ListNode node = IndexFromHash(hash_value);
958 return node.IsHead() ? node : ListNode();
959 }
960 /*! \brief Construct the number of blocks in the hash table */
961 static uint64_t CalcNumBlocks(uint64_t n_slots_m1) {
962 uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0;
963 return (n_slots + kBlockCap - 1) / kBlockCap;
964 }
965 /*!
966 * \brief Calculate the power-of-2 table size given the lower-bound of required capacity.
967 * \param cap The lower-bound of the required capacity
968 * \param fib_shift The result shift for Fibonacci Hashing
969 * \param n_slots The result number of slots
970 */
971 static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) {
972 uint32_t shift = 64;
973 uint64_t slots = 1;
974 for (uint64_t c = cap; c; c >>= 1) {
975 shift -= 1;
976 slots <<= 1;
977 }
978 ICHECK_GT(slots, cap);
979 if (slots < cap * 2) {
980 *fib_shift = shift - 1;
981 *n_slots = slots << 1;
982 } else {
983 *fib_shift = shift;
984 *n_slots = slots;
985 }
986 }
987 /*!
988 * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table.
989 * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/.
990 * \param hash_value The raw hash value
991 * \param fib_shift The shift in Fibonacci Hashing
992 * \return An index calculated using Fibonacci Hashing
993 */
994 static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) {
995 constexpr uint64_t coeff = 11400714819323198485ull;
996 return (coeff * hash_value) >> fib_shift;
997 }
  /*!
   * \brief The implicit in-place linked list used to index a chain.
   *
   * A ListNode is a cursor made of a slot index plus a pointer to the Block holding that slot.
   * Each slot has one metadata byte:
   * - its top bit (0b10000000) is clear for the head of a list and set for a tail entry;
   * - its low 7 bits select, via NextProbeLocation(), the offset to the next entry of the
   *   same list. An offset of 0 means there is no next entry.
   * kEmptySlot and kProtectedSlot are reserved metadata values.
   */
  struct ListNode {
    /*! \brief Construct None: a cursor that is not attached to any block */
    ListNode() : index(0), block(nullptr) {}
    /*! \brief Construct from position; `block` is the Block that contains slot `index` */
    ListNode(uint64_t index, const DenseMapNode* self)
        : index(index), block(self->data_ + (index / kBlockCap)) {}
    /*! \brief Metadata on the entry, stored in the first kBlockCap bytes of its block */
    uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); }
    /*! \brief Data on the entry, stored after the kBlockCap metadata bytes of its block */
    KVType& Data() const {
      return *(reinterpret_cast<KVType*>(block->bytes + kBlockCap +
                                         (index % kBlockCap) * sizeof(KVType)));
    }
    /*! \brief Key on the entry */
    key_type& Key() const { return Data().first; }
    /*! \brief Value on the entry */
    mapped_type& Val() const { return Data().second; }
    /*! \brief If the entry is head of linked list (top metadata bit clear) */
    bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; }
    /*! \brief If the cursor is None, i.e. not attached to any block */
    bool IsNone() const { return block == nullptr; }
    /*! \brief If the entry is empty slot */
    bool IsEmpty() const { return Meta() == uint8_t(kEmptySlot); }
    /*! \brief If the entry is protected slot */
    bool IsProtected() const { return Meta() == uint8_t(kProtectedSlot); }
    /*! \brief Set the entry to be empty (only the metadata is written; the KVType is not destroyed) */
    void SetEmpty() const { Meta() = uint8_t(kEmptySlot); }
    /*! \brief Set the entry to be protected */
    void SetProtected() const { Meta() = uint8_t(kProtectedSlot); }
    /*!
     * \brief Set the entry's jump to its next entry, keeping the head/tail bit.
     * \param jump Index into NextProbeLocation(). NOTE(review): presumably always < 128, so
     *        the head/tail bit is left untouched — TODO confirm at call sites.
     */
    void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; }
    /*! \brief Construct a head of linked list in-place (metadata 0: a head with no next entry) */
    void NewHead(KVType v) const {
      Meta() = 0b00000000;
      new (&Data()) KVType(std::move(v));
    }
    /*! \brief Construct a tail of linked list in-place (metadata 0b10000000: a tail with no next entry) */
    void NewTail(KVType v) const {
      Meta() = 0b10000000;
      new (&Data()) KVType(std::move(v));
    }
    /*! \brief If the entry has next entry on the linked list */
    bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; }
    /*!
     * \brief Move the cursor to the next entry on the linked list.
     * \param self The map that owns the slots
     * \param meta Metadata byte whose low 7 bits select the probe offset
     * \return true if moved; false if there is no next entry, in which case the cursor becomes None
     */
    bool MoveToNext(const DenseMapNode* self, uint8_t meta) {
      uint64_t offset = NextProbeLocation(meta & 0b01111111);
      if (offset == 0) {
        index = 0;
        block = nullptr;
        return false;
      }
      // slots_ is the (power-of-two) slot count minus one, so the mask wraps around the table.
      index = (index + offset) & (self->slots_);
      block = self->data_ + (index / kBlockCap);
      return true;
    }
    /*! \brief Move the entry to the next entry on the linked list, using this entry's own metadata */
    bool MoveToNext(const DenseMapNode* self) { return MoveToNext(self, Meta()); }
    /*!
     * \brief Get the previous entry on the linked list.
     *
     * Walks the list from its head, found by re-hashing this entry's key, until it reaches
     * this entry. NOTE(review): presumably must not be called on a list head. For a head,
     * the walk can run off the end of the list — TODO confirm with callers.
     */
    ListNode FindPrev(const DenseMapNode* self) const {
      // start from the head of the linked list, which must exist
      ListNode next = self->IndexFromHash(ObjectHash()(Key()));
      // `prev` is always the previous item of `next`
      ListNode prev = next;
      for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) {
      }
      return prev;
    }
    /*!
     * \brief Get the next empty jump.
     *
     * Tries the offsets NextProbeLocation(1) through NextProbeLocation(kNumJumpDists - 1)
     * from this entry, in order, and stops at the first empty slot.
     *
     * \param self The map that owns the slots
     * \param jump Output: the jump index that reaches the empty slot
     * \param result Output: the empty slot found
     * \return Whether an empty slot was found
     */
    bool GetNextEmpty(const DenseMapNode* self, uint8_t* jump, ListNode* result) const {
      for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) {
        ListNode candidate((index + NextProbeLocation(idx)) & (self->slots_), self);
        if (candidate.IsEmpty()) {
          *jump = idx;
          *result = candidate;
          return true;
        }
      }
      return false;
    }
    /*! \brief Index on the real array */
    uint64_t index;
    /*! \brief Pointer to the actual block */
    Block* block;
  };
1083
 protected:
  /*! \brief fib shift in Fibonacci Hashing: the right-shift FibHash() applies for this table size */
  uint32_t fib_shift_;
  /*! \brief array of data blocks, allocated with new[] in Empty()/CopyFrom() and freed by ReleaseMemory() */
  Block* data_;
  /*!
   * \brief Map a 7-bit jump index (from slot metadata) to a probing distance.
   *
   * Entry 0 is 0, meaning "no next entry". Entries 1-15 probe linearly. Later entries are
   * triangular numbers (quadratic probing), per the references below.
   *
   * \param index The jump index. NOTE(review): there is no bounds check; callers must pass
   *        a value < kNumJumpDists.
   * \return The offset, in slots, to the next probe location
   */
  static uint64_t NextProbeLocation(size_t index) {
    /* clang-format off */
    /*! \brief Candidates of probing distance */
    static const uint64_t kNextProbeLocation[kNumJumpDists] {
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      // Quadratic probing with triangle numbers. See also:
      // 1) https://en.wikipedia.org/wiki/Quadratic_probing
      // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
      // 3) https://github.com/skarupke/flat_hash_map
      21, 28, 36, 45, 55, 66, 78, 91, 105, 120,
      136, 153, 171, 190, 210, 231, 253, 276, 300, 325,
      351, 378, 406, 435, 465, 496, 528, 561, 595, 630,
      666, 703, 741, 780, 820, 861, 903, 946, 990, 1035,
      1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540,
      1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145,
      2211, 2278, 2346, 2415, 2485, 2556, 2628,
      // larger triangle numbers
      8515, 19110, 42778, 96141, 216153,
      486591, 1092981, 2458653, 5532801, 12442566,
      27993903, 62983476, 141717030, 318844378, 717352503,
      1614057336, 3631522476, 8170957530, 18384510628, 41364789378,
      93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695,
      5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000,
      309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701,
      17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251,
      457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435,
    };
    /* clang-format on */
    return kNextProbeLocation[index];
  }
1119 friend class MapNode;
1120};
1121
// Run `body` with `var` bound to the concrete node behind `base`: a DenseMapNode* when
// base->slots_ exceeds SmallMapNode::kMaxSize, a SmallMapNode* otherwise.
// `body` is expanded once in each branch.
#define TVM_DISPATCH_MAP(base, var, body)          \
  {                                                \
    using TSmall = SmallMapNode*;                  \
    using TDense = DenseMapNode*;                  \
    if ((base)->slots_ > SmallMapNode::kMaxSize) { \
      TDense var = static_cast<TDense>(base);      \
      body;                                        \
    } else {                                       \
      TSmall var = static_cast<TSmall>(base);      \
      body;                                        \
    }                                              \
  }
1135
// Const counterpart of TVM_DISPATCH_MAP: `var` is bound to a const DenseMapNode* or a
// const SmallMapNode*, depending on whether base->slots_ exceeds SmallMapNode::kMaxSize.
#define TVM_DISPATCH_MAP_CONST(base, var, body)    \
  {                                                \
    using TSmall = const SmallMapNode*;            \
    using TDense = const DenseMapNode*;            \
    if ((base)->slots_ > SmallMapNode::kMaxSize) { \
      TDense var = static_cast<TDense>(base);      \
      body;                                        \
    } else {                                       \
      TSmall var = static_cast<TSmall>(base);      \
      body;                                        \
    }                                              \
  }
1149
1150inline MapNode::iterator::pointer MapNode::iterator::operator->() const {
1151 TVM_MAP_FAIL_IF_CHANGED()
1152 TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
1153}
1154
1155inline MapNode::iterator& MapNode::iterator::operator++() {
1156 TVM_MAP_FAIL_IF_CHANGED()
1157 TVM_DISPATCH_MAP_CONST(self, p, {
1158 index = p->IncItr(index);
1159 return *this;
1160 });
1161}
1162
1163inline MapNode::iterator& MapNode::iterator::operator--() {
1164 TVM_MAP_FAIL_IF_CHANGED()
1165 TVM_DISPATCH_MAP_CONST(self, p, {
1166 index = p->DecItr(index);
1167 return *this;
1168 });
1169}
1170
1171inline size_t MapNode::count(const key_type& key) const {
1172 TVM_DISPATCH_MAP_CONST(this, p, { return p->count(key); });
1173}
1174
1175inline const MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) const {
1176 TVM_DISPATCH_MAP_CONST(this, p, { return p->at(key); });
1177}
1178
1179inline MapNode::mapped_type& MapNode::at(const MapNode::key_type& key) {
1180 TVM_DISPATCH_MAP(this, p, { return p->at(key); });
1181}
1182
1183inline MapNode::iterator MapNode::begin() const {
1184 TVM_DISPATCH_MAP_CONST(this, p, { return p->begin(); });
1185}
1186
1187inline MapNode::iterator MapNode::end() const {
1188 TVM_DISPATCH_MAP_CONST(this, p, { return p->end(); });
1189}
1190
1191inline MapNode::iterator MapNode::find(const MapNode::key_type& key) const {
1192 TVM_DISPATCH_MAP_CONST(this, p, { return p->find(key); });
1193}
1194
1195inline void MapNode::erase(const MapNode::iterator& position) {
1196 TVM_DISPATCH_MAP(this, p, { return p->erase(position); });
1197}
1198
1199#undef TVM_DISPATCH_MAP
1200#undef TVM_DISPATCH_MAP_CONST
1201
1202inline ObjectPtr<MapNode> MapNode::Empty() { return SmallMapNode::Empty(); }
1203
1204inline ObjectPtr<MapNode> MapNode::CopyFrom(MapNode* from) {
1205 if (from->slots_ <= SmallMapNode::kMaxSize) {
1206 return SmallMapNode::CopyFrom(static_cast<SmallMapNode*>(from));
1207 } else {
1208 return DenseMapNode::CopyFrom(static_cast<DenseMapNode*>(from));
1209 }
1210}
1211
1212template <typename IterType>
1213inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last) {
1214 int64_t _cap = std::distance(first, last);
1215 if (_cap < 0) {
1216 return SmallMapNode::Empty();
1217 }
1218 uint64_t cap = static_cast<uint64_t>(_cap);
1219 if (cap < SmallMapNode::kMaxSize) {
1220 return SmallMapNode::CreateFromRange(cap, first, last);
1221 }
1222 uint32_t fib_shift;
1223 uint64_t n_slots;
1224 DenseMapNode::CalcTableSize(cap, &fib_shift, &n_slots);
1225 ObjectPtr<Object> obj = DenseMapNode::Empty(fib_shift, n_slots);
1226 for (; first != last; ++first) {
1227 KVType kv(*first);
1228 DenseMapNode::InsertMaybeReHash(kv, &obj);
1229 }
1230 return obj;
1231}
1232
1233inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
1234 constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
1235 MapNode* base = static_cast<MapNode*>(map->get());
1236#if TVM_LOG_DEBUG
1237 base->state_marker++;
1238#endif // TVM_LOG_DEBUG
1239 if (base->slots_ < kSmallMapMaxSize) {
1240 SmallMapNode::InsertMaybeReHash(kv, map);
1241 } else if (base->slots_ == kSmallMapMaxSize) {
1242 if (base->size_ < base->slots_) {
1243 SmallMapNode::InsertMaybeReHash(kv, map);
1244 } else {
1245 ObjectPtr<Object> new_map = MapNode::CreateFromRange(base->begin(), base->end());
1246 DenseMapNode::InsertMaybeReHash(kv, &new_map);
1247 *map = std::move(new_map);
1248 }
1249 } else {
1250 DenseMapNode::InsertMaybeReHash(kv, map);
1251 }
1252}
1253
// Forbid make_object<MapNode>() with no arguments. This file creates map nodes through
// MapNode::Empty(), MapNode::CopyFrom() and MapNode::CreateFromRange() instead.
template <>
inline ObjectPtr<MapNode> make_object<>() = delete;
1256
1257#endif
1258
1259/*!
1260 * \brief Map container of NodeRef->NodeRef in DSL graph.
1261 * Map implements copy on write semantics, which means map is mutable
1262 * but copy will happen when array is referenced in more than two places.
1263 *
1264 * operator[] only provide const acces, use Set to mutate the content.
1265 * \tparam K The key NodeRef type.
1266 * \tparam V The value NodeRef type.
1267 */
1268template <typename K, typename V,
1269 typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1270 typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1271class Map : public ObjectRef {
1272 public:
1273 using key_type = K;
1274 using mapped_type = V;
1275 class iterator;
1276 /*!
1277 * \brief default constructor
1278 */
1279 Map() { data_ = MapNode::Empty(); }
1280 /*!
1281 * \brief move constructor
1282 * \param other source
1283 */
1284 Map(Map<K, V>&& other) { data_ = std::move(other.data_); }
1285 /*!
1286 * \brief copy constructor
1287 * \param other source
1288 */
1289 Map(const Map<K, V>& other) : ObjectRef(other.data_) {}
1290 /*!
1291 * \brief copy assign operator
1292 * \param other The source of assignment
1293 * \return reference to self.
1294 */
1295 Map<K, V>& operator=(Map<K, V>&& other) {
1296 data_ = std::move(other.data_);
1297 return *this;
1298 }
1299 /*!
1300 * \brief move assign operator
1301 * \param other The source of assignment
1302 * \return reference to self.
1303 */
1304 Map<K, V>& operator=(const Map<K, V>& other) {
1305 data_ = other.data_;
1306 return *this;
1307 }
1308 /*!
1309 * \brief constructor from pointer
1310 * \param n the container pointer
1311 */
1312 explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
1313 /*!
1314 * \brief constructor from iterator
1315 * \param begin begin of iterator
1316 * \param end end of iterator
1317 * \tparam IterType The type of iterator
1318 */
1319 template <typename IterType>
1320 Map(IterType begin, IterType end) {
1321 data_ = MapNode::CreateFromRange(begin, end);
1322 }
1323 /*!
1324 * \brief constructor from initializer list
1325 * \param init The initalizer list
1326 */
1327 Map(std::initializer_list<std::pair<K, V>> init) {
1328 data_ = MapNode::CreateFromRange(init.begin(), init.end());
1329 }
1330 /*!
1331 * \brief constructor from unordered_map
1332 * \param init The unordered_map
1333 */
1334 template <typename Hash, typename Equal>
1335 Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
1336 data_ = MapNode::CreateFromRange(init.begin(), init.end());
1337 }
1338 /*!
1339 * \brief Read element from map.
1340 * \param key The key
1341 * \return the corresonding element.
1342 */
1343 const V at(const K& key) const { return DowncastNoCheck<V>(GetMapNode()->at(key)); }
1344 /*!
1345 * \brief Read element from map.
1346 * \param key The key
1347 * \return the corresonding element.
1348 */
1349 const V operator[](const K& key) const { return this->at(key); }
1350 /*! \return The size of the array */
1351 size_t size() const {
1352 MapNode* n = GetMapNode();
1353 return n == nullptr ? 0 : n->size();
1354 }
1355 /*! \return The number of elements of the key */
1356 size_t count(const K& key) const {
1357 MapNode* n = GetMapNode();
1358 return n == nullptr ? 0 : GetMapNode()->count(key);
1359 }
1360 /*! \return whether array is empty */
1361 bool empty() const { return size() == 0; }
1362 /*! \brief Release reference to all the elements */
1363 void clear() {
1364 MapNode* n = GetMapNode();
1365 if (n != nullptr) {
1366 data_ = MapNode::Empty();
1367 }
1368 }
1369 /*!
1370 * \brief set the Map.
1371 * \param key The index key.
1372 * \param value The value to be setted.
1373 */
1374 void Set(const K& key, const V& value) {
1375 CopyOnWrite();
1376 MapNode::InsertMaybeReHash(MapNode::KVType(key, value), &data_);
1377 }
1378 /*! \return begin iterator */
1379 iterator begin() const { return iterator(GetMapNode()->begin()); }
1380 /*! \return end iterator */
1381 iterator end() const { return iterator(GetMapNode()->end()); }
1382 /*! \return find the key and returns the associated iterator */
1383 iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }
1384 /*! \return The value associated with the key, NullOpt if not found */
1385 Optional<V> Get(const K& key) const {
1386 MapNode::iterator iter = GetMapNode()->find(key);
1387 if (iter == GetMapNode()->end()) {
1388 return NullOptType{};
1389 }
1390 return DowncastNoCheck<V>(iter->second);
1391 }
1392 void erase(const K& key) { CopyOnWrite()->erase(key); }
1393
1394 /*!
1395 * \brief copy on write semantics
1396 * Do nothing if current handle is the unique copy of the array.
1397 * Otherwise make a new copy of the array to ensure the current handle
1398 * hold a unique copy.
1399 *
1400 * \return Handle to the internal node container(which guarantees to be unique)
1401 */
1402 MapNode* CopyOnWrite() {
1403 if (data_.get() == nullptr) {
1404 data_ = MapNode::Empty();
1405 } else if (!data_.unique()) {
1406 data_ = MapNode::CopyFrom(GetMapNode());
1407 }
1408 return GetMapNode();
1409 }
1410 /*! \brief specify container node */
1411 using ContainerType = MapNode;
1412
1413 /*! \brief Iterator of the hash map */
1414 class iterator {
1415 public:
1416 using iterator_category = std::bidirectional_iterator_tag;
1417 using difference_type = int64_t;
1418 using value_type = const std::pair<K, V>;
1419 using pointer = value_type*;
1420 using reference = value_type;
1421
1422 iterator() : itr() {}
1423
1424 /*! \brief Compare iterators */
1425 bool operator==(const iterator& other) const { return itr == other.itr; }
1426 /*! \brief Compare iterators */
1427 bool operator!=(const iterator& other) const { return itr != other.itr; }
1428 /*! \brief De-reference iterators is not allowed */
1429 pointer operator->() const = delete;
1430 /*! \brief De-reference iterators */
1431 reference operator*() const {
1432 auto& kv = *itr;
1433 return std::make_pair(DowncastNoCheck<K>(kv.first), DowncastNoCheck<V>(kv.second));
1434 }
1435 /*! \brief Prefix self increment, e.g. ++iter */
1436 iterator& operator++() {
1437 ++itr;
1438 return *this;
1439 }
1440 /*! \brief Suffix self increment */
1441 iterator operator++(int) {
1442 iterator copy = *this;
1443 ++(*this);
1444 return copy;
1445 }
1446
1447 private:
1448 iterator(const MapNode::iterator& itr) // NOLINT(*)
1449 : itr(itr) {}
1450
1451 template <typename, typename, typename, typename>
1452 friend class Map;
1453
1454 MapNode::iterator itr;
1455 };
1456
1457 private:
1458 /*! \brief Return data_ as type of pointer of MapNode */
1459 MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
1460};
1461
1462/*!
1463 * \brief Merge two Maps.
1464 * \param lhs the first Map to merge.
1465 * \param rhs the second Map to merge.
1466 * @return The merged Array. Original Maps are kept unchanged.
1467 */
1468template <typename K, typename V,
1469 typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
1470 typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
1471inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
1472 for (const auto& p : rhs) {
1473 lhs.Set(p.first, p.second);
1474 }
1475 return std::move(lhs);
1476}
1477
1478} // namespace runtime
1479
1480// expose the functions to the root namespace.
1481using runtime::Map;
1482using runtime::MapNode;
1483} // namespace tvm
1484
1485#endif // TVM_RUNTIME_CONTAINER_MAP_H_
1486