#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/util/order_preserving_flat_hash_map.h>
#include <gtest/gtest.h>
10
11namespace {
12
13#define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2);
14
15using dict_int_int =
16 ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>;
17
18dict_int_int test_dict(dict_int_int& dict) {
19 for (const auto i : c10::irange(100)) {
20 dict[i] = i + 1;
21 }
22
23 int64_t entry_i = 0;
24 for (auto entry : dict) {
25 TORCH_INTERNAL_ASSERT(
26 entry.first == entry_i && entry.second == entry_i + 1);
27 ++entry_i;
28 }
29
30 // erase a few entries by themselves
31 std::unordered_set<int64_t> erase_set = {0, 2, 9, 71};
32 for (auto erase : erase_set) {
33 dict.erase(erase);
34 }
35
36 // erase via iterators
37 auto begin = dict.begin();
38 for (const auto i : c10::irange(20)) {
39 (void)i; // Suppress unused variable warning
40 begin++;
41 }
42
43 auto end = begin;
44 for (const auto i : c10::irange(20)) {
45 (void)i; // Suppress unused variable warning
46 erase_set.insert(end->first);
47 end++;
48 }
49 dict.erase(begin, end);
50
51 std::vector<int64_t> order;
52 for (const auto i : c10::irange(100)) {
53 if (!erase_set.count(i)) {
54 order.push_back(i);
55 }
56 }
57
58 entry_i = 0;
59 for (auto entry : dict) {
60 TORCH_INTERNAL_ASSERT(order[entry_i] == entry.first);
61 TORCH_INTERNAL_ASSERT(dict[order[entry_i]] == entry.second);
62 TORCH_INTERNAL_ASSERT(entry.second == order[entry_i] + 1);
63 entry_i++;
64 }
65 TORCH_INTERNAL_ASSERT(dict.size() == order.size());
66 return dict;
67}
68
69TEST(OrderedPreservingDictTest, InsertAndDeleteBasic) {
70 dict_int_int dict;
71 test_dict(dict);
72 dict.clear();
73 test_dict(dict);
74}
75
76TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) {
77 dict_int_int dict;
78 dict[0] = 1;
79 dict[1] = 2;
80
81 TORCH_INTERNAL_ASSERT(dict.begin()->first == 0);
82 dict[0] = 1;
83 TORCH_INTERNAL_ASSERT(dict.begin()->first == 0);
84 dict[0] = 2;
85 TORCH_INTERNAL_ASSERT(dict.begin()->first == 0);
86
87 dict.erase(0);
88 TORCH_INTERNAL_ASSERT(dict.begin()->first == 1);
89}
90
91TEST(OrderedPreservingDictTest, testRefType) {
92 std::shared_ptr<int64_t> t;
93 using dict_references = ska_ordered::
94 order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>;
95
96 dict_references dict;
97
98 auto ptr = std::make_shared<int64_t>(1);
99 dict[1] = ptr;
100 TORCH_INTERNAL_ASSERT(ptr.use_count() == 2);
101 dict.erase(1);
102 TORCH_INTERNAL_ASSERT(ptr.use_count() == 1);
103
104 dict[2] = ptr;
105 dict.clear();
106 TORCH_INTERNAL_ASSERT(ptr.use_count() == 1);
107}
108
109TEST(OrderedPreservingDictTest, DictCollisions) {
110 struct BadHash {
111 size_t operator()(const int64_t input) {
112 return input % 2;
113 };
114 };
115
116 using bad_hash_dict =
117 ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t, BadHash>;
118
119 for (auto init_dict_size : {27, 34, 41}) {
120 bad_hash_dict dict;
121 for (const auto i : c10::irange(init_dict_size)) {
122 dict[i] = i + 1;
123 }
124
125 int64_t i = 0;
126 for (const auto& entry : dict) {
127 TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1);
128 ++i;
129 }
130
131 // erase a few entries;
132 std::unordered_set<int64_t> erase_set = {0, 2, 9};
133 for (auto erase : erase_set) {
134 dict.erase(erase);
135 }
136
137 // erase a few entries via iterator
138 auto begin = dict.begin();
139 for (const auto j : c10::irange(10)) {
140 (void)j; // Suppress unused variable warning
141 begin++;
142 }
143 auto end = begin;
144 for (const auto j : c10::irange(7)) {
145 (void)j; // Suppress unused variable warning
146 erase_set.insert(end->first);
147 end++;
148 }
149 dict.erase(begin, end);
150
151 std::vector<int64_t> order;
152 for (const auto j : c10::irange(init_dict_size)) {
153 if (!erase_set.count(j)) {
154 order.push_back(j);
155 }
156 }
157
158 i = 0;
159 for (auto entry : dict) {
160 TORCH_INTERNAL_ASSERT(dict[entry.first] == entry.second);
161 TORCH_INTERNAL_ASSERT(dict[entry.first] == order[i] + 1);
162 TORCH_INTERNAL_ASSERT(order[i] == entry.first);
163 i += 1;
164 }
165 TORCH_INTERNAL_ASSERT(dict.size() == order.size());
166 }
167}
168
169// Tests taken from
170// https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp
171
172TEST(OrderedPreservingDictTest, test_range_insert) {
173 // insert x values in vector, range insert x-15 values from vector to map,
174 // check values
175 const int nb_values = 1000;
176 std::vector<std::pair<int, int>> values;
177 for (const auto i : c10::irange(nb_values)) {
178 // NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation)
179 values.push_back(std::make_pair(i, i + 1));
180 }
181
182 dict_int_int map = {{-1, 0}, {-2, 0}};
183 map.insert(values.begin() + 10, values.end() - 5);
184
185 TORCH_INTERNAL_ASSERT(map.size(), 987);
186
187 ASSERT_EQUAL_PRIM(map.at(-1), 0);
188
189 ASSERT_EQUAL_PRIM(map.at(-2), 0);
190
191 for (int i = 10, j = 2; i < nb_values - 5; i++, j++) {
192 ASSERT_EQUAL_PRIM(map.at(i), i + 1);
193 }
194}
195
196TEST(OrderedPreservingDictTest, test_range_erase_all) {
197 // insert x values, delete all
198 const std::size_t nb_values = 1000;
199 dict_int_int map;
200 for (const auto i : c10::irange(nb_values)) {
201 map[i] = i + 1;
202 }
203 auto it = map.erase(map.begin(), map.end());
204 ASSERT_TRUE(it == map.end());
205 ASSERT_TRUE(map.empty());
206}
207
208TEST(OrderedPreservingDictTest, test_range_erase) {
209 // insert x values, delete all with iterators except 10 first and 780 last
210 // values
211 using HMap =
212 ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>;
213
214 const int64_t nb_values = 1000;
215 HMap map;
216 for (const auto i : c10::irange(nb_values)) {
217 map[c10::guts::to_string(i)] = i;
218 auto begin = map.begin();
219 for (int64_t j = 0; j <= i; ++j, begin++) {
220 TORCH_INTERNAL_ASSERT(begin->second == j);
221 }
222 }
223
224 auto it_first = std::next(map.begin(), 10);
225 auto it_last = std::next(map.begin(), 220);
226
227 auto it = map.erase(it_first, it_last);
228 ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780);
229 ASSERT_EQUAL_PRIM(map.size(), 790);
230 ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790);
231
232 for (auto& val : map) {
233 ASSERT_EQUAL_PRIM(map.count(val.first), 1);
234 }
235
236 // Check order
237 it = map.begin();
238 for (std::size_t i = 0; i < nb_values; i++) {
239 if (i >= 10 && i < 220) {
240 continue;
241 }
242 auto exp_it =
243 std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i);
244 TORCH_INTERNAL_ASSERT(*it == exp_it);
245 ++it;
246 }
247}
248
249TEST(OrderedPreservingDictTest, test_move_constructor_empty) {
250 ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
251 ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move(
252 std::move(map));
253
254 // NOLINTNEXTLINE(bugprone-use-after-move)
255 TORCH_INTERNAL_ASSERT(map.empty());
256 TORCH_INTERNAL_ASSERT(map_move.empty());
257
258 // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
259 TORCH_INTERNAL_ASSERT(map.find("") == map.end());
260 TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
261}
262
263TEST(OrderedPreservingDictTest, test_move_operator_empty) {
264 ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0);
265 ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move;
266 map_move = (std::move(map));
267
268 // NOLINTNEXTLINE(bugprone-use-after-move)
269 TORCH_INTERNAL_ASSERT(map.empty());
270 TORCH_INTERNAL_ASSERT(map_move.empty());
271
272 // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
273 TORCH_INTERNAL_ASSERT(map.find("") == map.end());
274 TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end());
275}
276
277TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) {
278 using HMap =
279 ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
280
281 HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
282 HMap map_move(std::move(map));
283
284 ASSERT_EQUAL_PRIM(map_move.size(), 3);
285 // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
286 ASSERT_EQUAL_PRIM(map.size(), 0);
287
288 map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
289 TORCH_INTERNAL_ASSERT(
290 map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
291}
292
293TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) {
294 using HMap =
295 ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
296
297 HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}};
298 HMap map_move = std::move(map);
299
300 ASSERT_EQUAL_PRIM(map_move.size(), 3);
301 // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move)
302 ASSERT_EQUAL_PRIM(map.size(), 0);
303
304 map = {{"Key4", "Value4"}, {"Key5", "Value5"}};
305 TORCH_INTERNAL_ASSERT(
306 map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}})));
307}
308
309TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) {
310 using HMap =
311 ska_ordered::order_preserving_flat_hash_map<std::string, std::string>;
312
313 const std::size_t nb_values = 100;
314 HMap map;
315 for (const auto i : c10::irange(nb_values)) {
316 map[c10::guts::to_string(i)] = c10::guts::to_string(i);
317 }
318
319 HMap map_copy = map;
320 HMap map_copy2(map);
321 HMap map_copy3;
322 map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0);
323
324 map_copy3 = map;
325
326 TORCH_INTERNAL_ASSERT(map == map_copy);
327 map.clear();
328
329 TORCH_INTERNAL_ASSERT(map_copy == map_copy2);
330 TORCH_INTERNAL_ASSERT(map_copy == map_copy3);
331}
332
333TEST(OrderedPreservingDictTest, test_copy_constructor_empty) {
334 ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
335 ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map);
336
337 TORCH_INTERNAL_ASSERT(map.empty());
338 TORCH_INTERNAL_ASSERT(map_copy.empty());
339
340 TORCH_INTERNAL_ASSERT(map.find("") == map.end());
341 TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
342}
343
344TEST(OrderedPreservingDictTest, test_copy_operator_empty) {
345 ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0);
346 ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16);
347 map_copy = map;
348
349 TORCH_INTERNAL_ASSERT(map.empty());
350 TORCH_INTERNAL_ASSERT(map_copy.empty());
351
352 TORCH_INTERNAL_ASSERT(map.find("") == map.end());
353 TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end());
354}
355
356/**
357 * at
358 */
359TEST(OrderedPreservingDictTest, test_at) {
360 // insert x values, use at for known and unknown values.
361 const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>
362 map = {{0, 10}, {-2, 20}};
363
364 ASSERT_EQUAL_PRIM(map.at(0), 10);
365 ASSERT_EQUAL_PRIM(map.at(-2), 20);
366 bool thrown = false;
367 try {
368 map.at(1);
369 } catch (...) {
370 thrown = true;
371 }
372 ASSERT_TRUE(thrown);
373}
374
375/**
376 * equal_range
377 */
378TEST(OrderedPreservingDictTest, test_equal_range) {
379 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
380 {{0, 10}, {-2, 20}};
381
382 auto it_pair = map.equal_range(0);
383 ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1);
384 ASSERT_EQUAL_PRIM(it_pair.first->second, 10);
385
386 it_pair = map.equal_range(1);
387 TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second);
388 TORCH_INTERNAL_ASSERT(it_pair.first == map.end());
389}
390
391/**
392 * operator[]
393 */
394TEST(OrderedPreservingDictTest, test_access_operator) {
395 // insert x values, use at for known and unknown values.
396 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
397 {{0, 10}, {-2, 20}};
398
399 ASSERT_EQUAL_PRIM(map[0], 10);
400 ASSERT_EQUAL_PRIM(map[-2], 20);
401 ASSERT_EQUAL_PRIM(map[2], std::int64_t());
402
403 ASSERT_EQUAL_PRIM(map.size(), 3);
404}
405
406/**
407 * swap
408 */
409TEST(OrderedPreservingDictTest, test_swap) {
410 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
411 {{1, 10}, {8, 80}, {3, 30}};
412 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 =
413 {{4, 40}, {5, 50}};
414
415 using std::swap;
416 swap(map, map2);
417
418 TORCH_INTERNAL_ASSERT(
419 map ==
420 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
421 {4, 40}, {5, 50}}));
422 TORCH_INTERNAL_ASSERT(
423 map2 ==
424 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
425 {1, 10}, {8, 80}, {3, 30}}));
426
427 map.insert({6, 60});
428 map2.insert({4, 40});
429
430 TORCH_INTERNAL_ASSERT(
431 map ==
432 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
433 {4, 40}, {5, 50}, {6, 60}}));
434 TORCH_INTERNAL_ASSERT(
435 map2 ==
436 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
437 {1, 10}, {8, 80}, {3, 30}, {4, 40}}));
438}
439
440TEST(OrderedPreservingDictTest, test_swap_empty) {
441 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map =
442 {{1, 10}, {8, 80}, {3, 30}};
443 ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2;
444
445 using std::swap;
446 swap(map, map2);
447
448 TORCH_INTERNAL_ASSERT(
449 map ==
450 (ska_ordered::
451 order_preserving_flat_hash_map<std::int64_t, std::int64_t>{}));
452 TORCH_INTERNAL_ASSERT(
453 map2 ==
454 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
455 {1, 10}, {8, 80}, {3, 30}}));
456
457 map.insert({6, 60});
458 map2.insert({4, 40});
459
460 TORCH_INTERNAL_ASSERT(
461 map ==
462 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
463 {6, 60}}));
464 TORCH_INTERNAL_ASSERT(
465 map2 ==
466 (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{
467 {1, 10}, {8, 80}, {3, 30}, {4, 40}}));
468}
469
470} // namespace
471