1 | #pragma once |
2 | |
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
9 | |
10 | namespace torch { |
/// An ordered dictionary implementation, akin to Python's `OrderedDict`.
///
/// Items are stored in insertion order in a `std::vector`; a hash map from each
/// key to its position in that vector provides lookup by key.
template <typename Key, typename Value>
class OrderedDict {
 public:
  /// A (key, value) pair.
  class Item;

  // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`.
  // Iterators point directly into the underlying vector, so any operation that
  // adds, removes or reallocates items (`insert()`, `update()`, `erase()`,
  // `clear()`, `reserve()`) may invalidate them.
  using Iterator = typename std::vector<Item>::iterator;
  using ConstIterator = typename std::vector<Item>::const_iterator;

  /// Constructs the `OrderedDict` with a short description of the kinds of keys
  /// stored in the `OrderedDict`. This description is used in error messages
  /// thrown by the `OrderedDict`.
  explicit OrderedDict(std::string key_description = "Key");

  /// Copy constructs this `OrderedDict` from `other`.
  OrderedDict(const OrderedDict& other);

  /// Assigns items from `other` to this `OrderedDict`.
  OrderedDict& operator=(const OrderedDict& other);

  // NB: Moves are defaulted because every member (`index_`, `items_`,
  // `key_description_`) is itself movable. An attempt to make these noexcept
  // (conditional on the move constructors of `index_` and `items_` being
  // noexcept) reportedly did not compile on Windows with the obvious spelling.
  OrderedDict(OrderedDict&& other) = default;
  OrderedDict& operator=(OrderedDict&& other) = default;

  ~OrderedDict() = default;

  /// Constructs a new `OrderedDict` and pre-populates it with the given
  /// `Item`s, in the order they appear in `initializer_list`.
  /*implicit */ OrderedDict(std::initializer_list<Item> initializer_list);

  /// Returns the key description string the `OrderedDict` was constructed with.
  const std::string& key_description() const noexcept;

  // Element Access

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& front();

  /// Returns the very first item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& front() const;

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  Item& back();

  /// Returns the very last item in the `OrderedDict` and throws an exception
  /// if it is empty.
  const Item& back() const;

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  Item& operator[](size_t index);

  /// Returns the item at the `index`-th position in the `OrderedDict`. Throws
  /// an exception if the index is out of bounds.
  const Item& operator[](size_t index) const;

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  Value& operator[](const Key& key);

  /// Returns the value associated with the given `key`. Throws an exception if
  /// no such key is stored in the `OrderedDict`. Use `find()` for a
  /// non-throwing way of accessing a value if it is present.
  const Value& operator[](const Key& key) const;

  // Lookup

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`. The pointer
  /// refers into the item storage and is invalidated like an iterator.
  Value* find(const Key& key) noexcept;

  /// Returns a pointer to the value associated with the given key, or a
  /// `nullptr` if no such key is stored in the `OrderedDict`. The pointer
  /// refers into the item storage and is invalidated like an iterator.
  const Value* find(const Key& key) const noexcept;

  /// Returns true if the key is present in the `OrderedDict`.
  bool contains(const Key& key) const noexcept;

  // Iterators

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  Iterator begin();

  /// Returns an iterator to the first item in the `OrderedDict`. Iteration is
  /// ordered.
  ConstIterator begin() const;

  /// Returns an iterator one past the last item in the `OrderedDict`.
  Iterator end();

  /// Returns an iterator one past the last item in the `OrderedDict`.
  ConstIterator end() const;

  // Capacity

  /// Returns the number of items currently stored in the `OrderedDict`.
  size_t size() const noexcept;

  /// Returns true if the `OrderedDict` contains no elements.
  bool is_empty() const noexcept;

  /// Resizes internal storage to fit at least `requested_capacity` items
  /// without requiring reallocation.
  void reserve(size_t requested_capacity);

  // Modifiers

  /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an
  /// exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value. That reference
  /// is invalidated by any later operation that reallocates the items.
  template <typename K, typename V>
  Value& insert(K&& key, V&& value);

  /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an
  /// exception if the key is already present. If insertion is successful,
  /// immediately returns a reference to the inserted value. That reference
  /// is invalidated by any later operation that reallocates the items.
  Value& insert(Key key, Value&& value);

  /// Inserts all items from `other` into this `OrderedDict`, moving the values
  /// out of `other`. If any key from `other` is already present in this
  /// `OrderedDict`, an exception is thrown; items of `other` that precede the
  /// offending key will already have been inserted.
  void update(OrderedDict&& other);

  /// Inserts copies of all items from `other` into this `OrderedDict`. If any
  /// key from `other` is already present in this `OrderedDict`, an exception
  /// is thrown; items of `other` that precede the offending key will already
  /// have been inserted.
  void update(const OrderedDict& other);

  /// Removes the item with the given `key` from this `OrderedDict`, shifting
  /// all later items one position towards the front. Throws an exception if no
  /// such key is stored. Runs in time linear in the number of items.
  void erase(const Key& key);

  /// Removes all items from this `OrderedDict`.
  void clear();

  // Observers

  /// Returns the items stored in the `OrderedDict`.
  const std::vector<Item>& items() const noexcept;

  /// Returns a newly allocated vector and copies all keys from this
  /// `OrderedDict` into the vector.
  ::std::vector<Key> keys() const;

  /// Returns a newly allocated vector and copies all values from this
  /// `OrderedDict` into the vector.
  ::std::vector<Value> values() const;

  /// Returns a newly allocated vector and copies all keys and values from this
  /// `OrderedDict` into a vector of `std::pair<Key, Value>`.
  ::std::vector<std::pair<Key, Value>> pairs() const;

  /// Returns true if both dicts contain the same keys and values, in the same
  /// order. Requires `Value` to be equality-comparable.
  template <typename K, typename V>
  friend bool operator==(
      const OrderedDict<K, V>& a,
      const OrderedDict<K, V>& b);

 private:
  /// A mapping from a key to an index into the `items_` vector.
  ::std::unordered_map<Key, size_t> index_;

  /// The items stored in the `OrderedDict`, in insertion order.
  ::std::vector<Item> items_;

  /// A description of the keys stored in the `OrderedDict`.
  ::std::string key_description_{"Key"};
};
190 | |
191 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
192 | |
193 | template <typename Key, typename Value> |
194 | class OrderedDict<Key, Value>::Item { |
195 | public: |
196 | /// Constructs a new item. |
197 | Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {} |
198 | |
199 | /// Returns a reference to the value. |
200 | Value& operator*() { |
201 | return value(); |
202 | } |
203 | |
204 | /// Returns a reference to the value. |
205 | const Value& operator*() const { |
206 | return value(); |
207 | } |
208 | |
209 | /// Allows access to the value using the arrow operator. |
210 | Value* operator->() { |
211 | return &value(); |
212 | } |
213 | |
214 | /// Allows access to the value using the arrow operator. |
215 | const Value* operator->() const { |
216 | return &value(); |
217 | } |
218 | |
219 | /// Returns a reference to the key. |
220 | const Key& key() const noexcept { |
221 | return pair_.first; |
222 | } |
223 | |
224 | /// Returns a reference to the value. |
225 | Value& value() noexcept { |
226 | return pair_.second; |
227 | } |
228 | |
229 | /// Returns a reference to the value. |
230 | const Value& value() const noexcept { |
231 | return pair_.second; |
232 | } |
233 | |
234 | /// Returns a `(key, value)` pair. |
235 | const std::pair<Key, Value>& pair() const noexcept { |
236 | return pair_; |
237 | } |
238 | |
239 | private: |
240 | /// This is stored as an std::pair because it will make Python binding a lot, |
241 | /// lot easier. |
242 | ::std::pair<Key, Value> pair_; |
243 | }; |
244 | |
245 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
246 | |
247 | template <typename Key, typename Value> |
248 | OrderedDict<Key, Value>::OrderedDict(std::string key_description) |
249 | : key_description_(std::move(key_description)) {} |
250 | |
251 | template <typename Key, typename Value> |
252 | OrderedDict<Key, Value>::OrderedDict(const OrderedDict& other) |
253 | : index_(other.index_), key_description_(other.key_description_) { |
254 | // Copy we have to do ourselves, because items' keys are const, so we have to |
255 | // re-insert the items. |
256 | for (const auto& item : other.items_) { |
257 | items_.push_back(item); |
258 | } |
259 | } |
260 | |
261 | template <typename Key, typename Value> |
262 | OrderedDict<Key, Value>& OrderedDict<Key, Value>::operator=( |
263 | const OrderedDict& other) { |
264 | index_ = other.index_; |
265 | items_.clear(); |
266 | for (auto& item : other.items_) { |
267 | items_.push_back(item); |
268 | } |
269 | key_description_ = other.key_description_; |
270 | return *this; |
271 | } |
272 | |
273 | template <typename Key, typename Value> |
274 | OrderedDict<Key, Value>::OrderedDict( |
275 | std::initializer_list<Item> initializer_list) |
276 | : OrderedDict("Key" ) { |
277 | items_.reserve(initializer_list.size()); |
278 | for (auto& item : initializer_list) { |
279 | // Copy the key here and move it into the index. |
280 | items_.emplace_back(item.key(), std::move(item.value())); |
281 | index_.emplace(std::move(item.key()), size() - 1); |
282 | } |
283 | } |
284 | |
285 | template <typename Key, typename Value> |
286 | typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::begin() { |
287 | return items_.begin(); |
288 | } |
289 | |
290 | template <typename Key, typename Value> |
291 | typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::begin() |
292 | const { |
293 | return items_.begin(); |
294 | } |
295 | |
296 | template <typename Key, typename Value> |
297 | typename OrderedDict<Key, Value>::Iterator OrderedDict<Key, Value>::end() { |
298 | return items_.end(); |
299 | } |
300 | |
301 | template <typename Key, typename Value> |
302 | typename OrderedDict<Key, Value>::ConstIterator OrderedDict<Key, Value>::end() |
303 | const { |
304 | return items_.end(); |
305 | } |
306 | |
307 | template <typename Key, typename Value> |
308 | typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() { |
309 | TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict" ); |
310 | return items_.front(); |
311 | } |
312 | |
313 | template <typename Key, typename Value> |
314 | const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::front() |
315 | const { |
316 | TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict" ); |
317 | return items_.front(); |
318 | } |
319 | |
320 | template <typename Key, typename Value> |
321 | typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() { |
322 | TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict" ); |
323 | return items_.back(); |
324 | } |
325 | |
326 | template <typename Key, typename Value> |
327 | const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::back() |
328 | const { |
329 | TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict" ); |
330 | return items_.back(); |
331 | } |
332 | |
333 | template <typename Key, typename Value> |
334 | typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>::operator[]( |
335 | size_t index) { |
336 | TORCH_CHECK(index < items_.size(), "Index " , index, " is out of bounds" ); |
337 | return items_[index]; |
338 | } |
339 | |
340 | template <typename Key, typename Value> |
341 | const typename OrderedDict<Key, Value>::Item& OrderedDict<Key, Value>:: |
342 | operator[](size_t index) const { |
343 | TORCH_CHECK(index < items_.size(), "Index " , index, " is out of bounds" ); |
344 | return items_[index]; |
345 | } |
346 | |
347 | template <typename Key, typename Value> |
348 | Value& OrderedDict<Key, Value>::operator[](const Key& key) { |
349 | if (auto* value = find(key)) { |
350 | return *value; |
351 | } |
352 | AT_ERROR(key_description_, " '" , key, "' is not defined" ); |
353 | } |
354 | |
355 | template <typename Key, typename Value> |
356 | const Value& OrderedDict<Key, Value>::operator[](const Key& key) const { |
357 | if (auto* value = find(key)) { |
358 | return *value; |
359 | } |
360 | AT_ERROR(key_description_, " '" , key, "' is not defined" ); |
361 | } |
362 | |
363 | template <typename Key, typename Value> |
364 | template <typename K, typename V> |
365 | Value& OrderedDict<Key, Value>::insert(K&& key, V&& value) { |
366 | TORCH_CHECK( |
367 | index_.count(key) == 0, key_description_, " '" , key, "' already defined" ); |
368 | // Copy `key` here and move it into the index. |
369 | items_.emplace_back(key, std::forward<V>(value)); |
370 | index_.emplace(std::forward<K>(key), size() - 1); |
371 | return items_.back().value(); |
372 | } |
373 | |
374 | template <typename Key, typename Value> |
375 | Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) { |
376 | return insert<Key, Value>(std::move(key), std::move(value)); |
377 | } |
378 | |
379 | template <typename Key, typename Value> |
380 | void OrderedDict<Key, Value>::update(OrderedDict&& other) { |
381 | reserve(size() + other.size()); |
382 | for (auto& item : other) { |
383 | // We want to call `insert()` to prevent duplicate keys. |
384 | insert(std::move(item.key()), std::move(item.value())); |
385 | } |
386 | } |
387 | |
388 | template <typename Key, typename Value> |
389 | void OrderedDict<Key, Value>::update(const OrderedDict& other) { |
390 | reserve(size() + other.size()); |
391 | for (auto& item : other) { |
392 | // We want to call `insert()` to prevent duplicate keys. |
393 | insert(item.key(), item.value()); |
394 | } |
395 | } |
396 | |
397 | template <typename Key, typename Value> |
398 | Value* OrderedDict<Key, Value>::find(const Key& key) noexcept { |
399 | auto iterator = index_.find(key); |
400 | if (iterator == index_.end()) { |
401 | return nullptr; |
402 | } |
403 | return &items_[iterator->second].value(); |
404 | } |
405 | |
406 | template <typename Key, typename Value> |
407 | const Value* OrderedDict<Key, Value>::find(const Key& key) const noexcept { |
408 | auto iterator = index_.find(key); |
409 | if (iterator == index_.end()) { |
410 | return nullptr; |
411 | } |
412 | return &items_[iterator->second].value(); |
413 | } |
414 | |
415 | template <typename Key, typename Value> |
416 | void OrderedDict<Key, Value>::erase(const Key& key) { |
417 | auto it = index_.find(key); |
418 | TORCH_CHECK(it != index_.end(), "Key '" , key, "' doesn't exist" ); |
419 | |
420 | auto index = it->second; |
421 | index_.erase(it); |
422 | items_.erase(items_.begin() + index); |
423 | |
424 | for (auto& pair : index_) |
425 | if (pair.second > index) |
426 | --pair.second; |
427 | } |
428 | |
429 | template <typename Key, typename Value> |
430 | bool OrderedDict<Key, Value>::contains(const Key& key) const noexcept { |
431 | return find(key) != nullptr; |
432 | } |
433 | |
434 | template <typename Key, typename Value> |
435 | void OrderedDict<Key, Value>::clear() { |
436 | index_.clear(); |
437 | items_.clear(); |
438 | } |
439 | |
440 | template <typename Key, typename Value> |
441 | size_t OrderedDict<Key, Value>::size() const noexcept { |
442 | return items_.size(); |
443 | } |
444 | |
445 | template <typename Key, typename Value> |
446 | bool OrderedDict<Key, Value>::is_empty() const noexcept { |
447 | return items_.empty(); |
448 | } |
449 | |
450 | template <typename Key, typename Value> |
451 | const std::string& OrderedDict<Key, Value>::key_description() const noexcept { |
452 | return key_description_; |
453 | } |
454 | |
455 | template <typename Key, typename Value> |
456 | const std::vector<typename OrderedDict<Key, Value>::Item>& OrderedDict< |
457 | Key, |
458 | Value>::items() const noexcept { |
459 | return items_; |
460 | } |
461 | |
462 | template <typename Key, typename Value> |
463 | ::std::vector<Key> OrderedDict<Key, Value>::keys() const { |
464 | std::vector<Key> keys; |
465 | keys.reserve(size()); |
466 | for (const auto& item : items_) { |
467 | keys.push_back(item.key()); |
468 | } |
469 | return keys; |
470 | } |
471 | |
472 | template <typename Key, typename Value> |
473 | ::std::vector<Value> OrderedDict<Key, Value>::values() const { |
474 | std::vector<Value> values; |
475 | values.reserve(size()); |
476 | for (const auto& item : items_) { |
477 | values.push_back(item.value()); |
478 | } |
479 | return values; |
480 | } |
481 | |
482 | template <typename Key, typename Value> |
483 | ::std::vector<std::pair<Key, Value>> OrderedDict<Key, Value>::pairs() const { |
484 | std::vector<std::pair<Key, Value>> values; |
485 | values.reserve(size()); |
486 | for (const auto& item : items_) { |
487 | values.push_back(item.pair()); |
488 | } |
489 | return values; |
490 | } |
491 | |
492 | template <typename Key, typename Value> |
493 | void OrderedDict<Key, Value>::reserve(size_t requested_capacity) { |
494 | index_.reserve(requested_capacity); |
495 | items_.reserve(requested_capacity); |
496 | } |
497 | |
498 | template <typename K, typename V> |
499 | bool operator==( |
500 | const torch::OrderedDict<K, V>& a, |
501 | const torch::OrderedDict<K, V>& b) { |
502 | using Item = typename torch::OrderedDict<K, V>::Item; |
503 | if (a.index_ != b.index_) |
504 | return false; |
505 | if (a.items_.size() != b.items_.size()) |
506 | return false; |
507 | // NOTE: There's no point in comparing keys for items_, as we already know |
508 | // that index is equal. |
509 | return std::equal( |
510 | a.items_.begin(), |
511 | a.items_.end(), |
512 | b.items_.begin(), |
513 | [](const Item& a, const Item& b) { return a.value() == b.value(); }); |
514 | } |
515 | |
516 | } // namespace torch |
517 | |