1#pragma once
2
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9
10namespace torch {
/// An ordered dictionary implementation, akin to Python's `OrderedDict`.
///
/// Items are stored in insertion order in a `std::vector<Item>`; a
/// `std::unordered_map` maps every key to the position of its item in that
/// vector, which is what makes key lookup possible without a linear scan.
template <typename Key, typename Value>
class OrderedDict {
 public:
  /// A (key, value) pair.
  class Item;

  // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`.
  // Iterators are plain `std::vector` iterators, so any operation that changes
  // the item vector (`insert()`, `update()`, `erase()`, `clear()`,
  // `reserve()`) may invalidate iterators pointing into the vector.
  using Iterator = typename std::vector<Item>::iterator;
  using ConstIterator = typename std::vector<Item>::const_iterator;

  /// Constructs the `OrderedDict` with a short description of the kinds of keys
  /// stored in the `OrderedDict`. This description is used in error messages
  /// thrown by the `OrderedDict`.
  explicit OrderedDict(std::string key_description = "Key");

  /// Copy constructs this `OrderedDict` from `other`, including its key
  /// description.
  OrderedDict(const OrderedDict& other);

  /// Replaces the items and the key description of this `OrderedDict` with
  /// copies of those of `other`.
  OrderedDict& operator=(const OrderedDict& other);

  // NB: The compiler-generated move operations are sufficient here. They are
  // not declared `noexcept`: the original author tried to make them
  // conditionally noexcept (on the move constructors of index_ and items_),
  // but the obvious spelling didn't compile on Windows.
  OrderedDict(OrderedDict&& other) = default;
  OrderedDict& operator=(OrderedDict&& other) = default;

  ~OrderedDict() = default;

  /// Constructs a new `OrderedDict` and pre-populates it with the given
  /// `Item`s, in the order they are listed. The key description is "Key".
  /*implicit */ OrderedDict(std::initializer_list<Item> initializer_list);

  /// Returns the key description string the `OrderedDict` was constructed with.
  const std::string& key_description() const noexcept;

  // Element Access

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& front();

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& front() const;

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& back();

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& back() const;

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  Item& operator[](size_t index);

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  const Item& operator[](size_t index) const;

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  Value& operator[](const Key& key);

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  const Value& operator[](const Key& key) const;

  // Lookup

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  Value* find(const Key& key) noexcept;

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`.
  const Value* find(const Key& key) const noexcept;

  /// Returns true if the key is present in the `OrderedDict`.
  bool contains(const Key& key) const noexcept;

  // Iterators

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  Iterator begin();

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  ConstIterator begin() const;

  /// Returns an iterator one past the last item in the `OrderedDict`.
  Iterator end();

  /// Returns an iterator one past the last item in the `OrderedDict`.
  ConstIterator end() const;

  // Capacity

  /// Returns the number of items currently stored in the `OrderedDict`.
  size_t size() const noexcept;

  /// Returns true if the `OrderedDict` contains no elements.
  bool is_empty() const noexcept;

  /// Resizes internal storage (both the key index and the item vector) to fit
  /// at least `requested_capacity` items without requiring reallocation.
  void reserve(size_t requested_capacity);

  // Modifiers

  /// Inserts a new `(key, value)` pair at the end of the `OrderedDict`. Throws
  /// an exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value.
  template <typename K, typename V>
  Value& insert(K&& key, V&& value);

  /// Inserts a new `(key, value)` pair at the end of the `OrderedDict`. Throws
  /// an exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value. This overload
  /// forwards to the templated `insert()` above.
  Value& insert(Key key, Value&& value);

  /// Inserts all items from `other` into this `OrderedDict`, moving the values
  /// out of `other`. If any key from `other` is already present in this
  /// `OrderedDict`, an exception is thrown.
  void update(OrderedDict&& other);

  /// Inserts copies of all items from `other` into this `OrderedDict`. If any
  /// key from `other` is already present in this `OrderedDict`, an exception
  /// is thrown.
  void update(const OrderedDict& other);

  /// Removes the item that has `key` from this `OrderedDict`, shifting all
  /// later items one position towards the front. Throws an exception if no
  /// such key exists.
  void erase(const Key& key);

  /// Removes all items from this `OrderedDict`.
  void clear();

  // Observers

  /// Returns the items stored in the `OrderedDict`, in order.
  const std::vector<Item>& items() const noexcept;

  /// Returns a newly allocated vector and copies all keys from this
  /// `OrderedDict` into the vector.
  ::std::vector<Key> keys() const;

  /// Returns a newly allocated vector and copies all values from this
  /// `OrderedDict` into the vector.
  ::std::vector<Value> values() const;

  /// Returns a newly allocated vector and copies all keys and values from this
  /// `OrderedDict` into a vector of `std::pair<Key, Value>`.
  ::std::vector<std::pair<Key, Value>> pairs() const;

  /// Returns true if both dicts contain the same keys and values, in the same
  /// order. The key description is not part of the comparison.
  template <typename K, typename V>
  friend bool operator==(
      const OrderedDict<K, V>& a,
      const OrderedDict<K, V>& b);

 private:
  /// A mapping from a key to an index into the `items_` vector.
  ::std::unordered_map<Key, size_t> index_;

  /// The items stored in the `OrderedDict`, in insertion order.
  ::std::vector<Item> items_;

  /// A description of the keys stored in the `OrderedDict`, used as a prefix
  /// in error messages.
  ::std::string key_description_{"Key"};
};
190
191// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
192
193template <typename Key, typename Value>
194class OrderedDict<Key, Value>::Item {
195 public:
196 /// Constructs a new item.
197 Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {}
198
199 /// Returns a reference to the value.
200 Value& operator*() {
201 return value();
202 }
203
204 /// Returns a reference to the value.
205 const Value& operator*() const {
206 return value();
207 }
208
209 /// Allows access to the value using the arrow operator.
210 Value* operator->() {
211 return &value();
212 }
213
214 /// Allows access to the value using the arrow operator.
215 const Value* operator->() const {
216 return &value();
217 }
218
219 /// Returns a reference to the key.
220 const Key& key() const noexcept {
221 return pair_.first;
222 }
223
224 /// Returns a reference to the value.
225 Value& value() noexcept {
226 return pair_.second;
227 }
228
229 /// Returns a reference to the value.
230 const Value& value() const noexcept {
231 return pair_.second;
232 }
233
234 /// Returns a `(key, value)` pair.
235 const std::pair<Key, Value>& pair() const noexcept {
236 return pair_;
237 }
238
239 private:
240 /// This is stored as an std::pair because it will make Python binding a lot,
241 /// lot easier.
242 ::std::pair<Key, Value> pair_;
243};
244
245// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
246
247template <typename Key, typename Value>
248OrderedDict<Key, Value>::OrderedDict(std::string key_description)
249 : key_description_(std::move(key_description)) {}
250
251template <typename Key, typename Value>
252OrderedDict<Key, Value>::OrderedDict(const OrderedDict& other)
253 : index_(other.index_), key_description_(other.key_description_) {
254 // Copy we have to do ourselves, because items' keys are const, so we have to
255 // re-insert the items.
256 for (const auto& item : other.items_) {
257 items_.push_back(item);
258 }
259}
260
261template <typename Key, typename Value>
262OrderedDict<Key, Value>& OrderedDict<Key, Value>::operator=(
263 const OrderedDict& other) {
264 index_ = other.index_;
265 items_.clear();
266 for (auto& item : other.items_) {
267 items_.push_back(item);
268 }
269 key_description_ = other.key_description_;
270 return *this;
271}
272
273template <typename Key, typename Value>
274OrderedDict<Key, Value>::OrderedDict(
275 std::initializer_list<Item> initializer_list)
276 : OrderedDict("Key") {
277 items_.reserve(initializer_list.size());
278 for (auto& item : initializer_list) {
279 // Copy the key here and move it into the index.
280 items_.emplace_back(item.key(), std::move(item.value()));
281 index_.emplace(std::move(item.key()), size() - 1);
282 }
283}
284
285template <typename Key, typename Value>
286typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::begin() {
287 return items_.begin();
288}
289
290template <typename Key, typename Value>
291typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::begin()
292 const {
293 return items_.begin();
294}
295
296template <typename Key, typename Value>
297typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::end() {
298 return items_.end();
299}
300
301template <typename Key, typename Value>
302typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::end()
303 const {
304 return items_.end();
305}
306
307template <typename Key, typename Value>
308typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() {
309 TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
310 return items_.front();
311}
312
313template <typename Key, typename Value>
314const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front()
315 const {
316 TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict");
317 return items_.front();
318}
319
320template <typename Key, typename Value>
321typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() {
322 TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
323 return items_.back();
324}
325
326template <typename Key, typename Value>
327const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back()
328 const {
329 TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict");
330 return items_.back();
331}
332
333template <typename Key, typename Value>
334typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::operator[](
335 size_t index) {
336 TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
337 return items_[index];
338}
339
340template <typename Key, typename Value>
341const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::
342operator[](size_t index) const {
343 TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds");
344 return items_[index];
345}
346
347template <typename Key, typename Value>
348Value& OrderedDict<Key, Value>::operator[](const Key& key) {
349 if (auto* value = find(key)) {
350 return *value;
351 }
352 AT_ERROR(key_description_, " '", key, "' is not defined");
353}
354
355template <typename Key, typename Value>
356const Value& OrderedDict<Key, Value>::operator[](const Key& key) const {
357 if (auto* value = find(key)) {
358 return *value;
359 }
360 AT_ERROR(key_description_, " '", key, "' is not defined");
361}
362
363template <typename Key, typename Value>
364template <typename K, typename V>
365Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) {
366 TORCH_CHECK(
367 index_.count(key) == 0, key_description_, " '", key, "' already defined");
368 // Copy `key` here and move it into the index.
369 items_.emplace_back(key, std::forward<V>(value));
370 index_.emplace(std::forward<K>(key), size() - 1);
371 return items_.back().value();
372}
373
374template <typename Key, typename Value>
375Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) {
376 return insert<Key, Value>(std::move(key), std::move(value));
377}
378
379template <typename Key, typename Value>
380void OrderedDict<Key, Value>::update(OrderedDict&& other) {
381 reserve(size() + other.size());
382 for (auto& item : other) {
383 // We want to call `insert()` to prevent duplicate keys.
384 insert(std::move(item.key()), std::move(item.value()));
385 }
386}
387
388template <typename Key, typename Value>
389void OrderedDict<Key, Value>::update(const OrderedDict& other) {
390 reserve(size() + other.size());
391 for (auto& item : other) {
392 // We want to call `insert()` to prevent duplicate keys.
393 insert(item.key(), item.value());
394 }
395}
396
397template <typename Key, typename Value>
398Value* OrderedDict<Key, Value>::find(const Key& key) noexcept {
399 auto iterator = index_.find(key);
400 if (iterator == index_.end()) {
401 return nullptr;
402 }
403 return &items_[iterator->second].value();
404}
405
406template <typename Key, typename Value>
407const Value* OrderedDict<Key, Value>::find(const Key& key) const noexcept {
408 auto iterator = index_.find(key);
409 if (iterator == index_.end()) {
410 return nullptr;
411 }
412 return &items_[iterator->second].value();
413}
414
415template <typename Key, typename Value>
416void OrderedDict<Key, Value>::erase(const Key& key) {
417 auto it = index_.find(key);
418 TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist");
419
420 auto index = it->second;
421 index_.erase(it);
422 items_.erase(items_.begin() + index);
423
424 for (auto& pair : index_)
425 if (pair.second > index)
426 --pair.second;
427}
428
429template <typename Key, typename Value>
430bool OrderedDict<Key, Value>::contains(const Key& key) const noexcept {
431 return find(key) != nullptr;
432}
433
434template <typename Key, typename Value>
435void OrderedDict<Key, Value>::clear() {
436 index_.clear();
437 items_.clear();
438}
439
440template <typename Key, typename Value>
441size_t OrderedDict<Key, Value>::size() const noexcept {
442 return items_.size();
443}
444
445template <typename Key, typename Value>
446bool OrderedDict<Key, Value>::is_empty() const noexcept {
447 return items_.empty();
448}
449
450template <typename Key, typename Value>
451const std::string& OrderedDict<Key, Value>::key_description() const noexcept {
452 return key_description_;
453}
454
455template <typename Key, typename Value>
456const std::vector<typename OrderedDict<Key, Value>::Item>& OrderedDict<
457 Key,
458 Value>::items() const noexcept {
459 return items_;
460}
461
462template <typename Key, typename Value>
463::std::vector<Key> OrderedDict<Key, Value>::keys() const {
464 std::vector<Key> keys;
465 keys.reserve(size());
466 for (const auto& item : items_) {
467 keys.push_back(item.key());
468 }
469 return keys;
470}
471
472template <typename Key, typename Value>
473::std::vector<Value> OrderedDict<Key, Value>::values() const {
474 std::vector<Value> values;
475 values.reserve(size());
476 for (const auto& item : items_) {
477 values.push_back(item.value());
478 }
479 return values;
480}
481
482template <typename Key, typename Value>
483::std::vector<std::pair<Key, Value>> OrderedDict<Key, Value>::pairs() const {
484 std::vector<std::pair<Key, Value>> values;
485 values.reserve(size());
486 for (const auto& item : items_) {
487 values.push_back(item.pair());
488 }
489 return values;
490}
491
492template <typename Key, typename Value>
493void OrderedDict<Key, Value>::reserve(size_t requested_capacity) {
494 index_.reserve(requested_capacity);
495 items_.reserve(requested_capacity);
496}
497
498template <typename K, typename V>
499bool operator==(
500 const torch::OrderedDict<K, V>& a,
501 const torch::OrderedDict<K, V>& b) {
502 using Item = typename torch::OrderedDict<K, V>::Item;
503 if (a.index_ != b.index_)
504 return false;
505 if (a.items_.size() != b.items_.size())
506 return false;
507 // NOTE: There's no point in comparing keys for items_, as we already know
508 // that index is equal.
509 return std::equal(
510 a.items_.begin(),
511 a.items_.end(),
512 b.items_.begin(),
513 [](const Item& a, const Item& b) { return a.value() == b.value(); });
514}
515
516} // namespace torch
517