#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/util/order_preserving_flat_hash_map.h>
#include <gtest/gtest.h>
10 | |
11 | namespace { |
12 | |
13 | #define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2); |
14 | |
15 | using dict_int_int = |
16 | ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t>; |
17 | |
18 | dict_int_int test_dict(dict_int_int& dict) { |
19 | for (const auto i : c10::irange(100)) { |
20 | dict[i] = i + 1; |
21 | } |
22 | |
23 | int64_t entry_i = 0; |
24 | for (auto entry : dict) { |
25 | TORCH_INTERNAL_ASSERT( |
26 | entry.first == entry_i && entry.second == entry_i + 1); |
27 | ++entry_i; |
28 | } |
29 | |
30 | // erase a few entries by themselves |
31 | std::unordered_set<int64_t> erase_set = {0, 2, 9, 71}; |
32 | for (auto erase : erase_set) { |
33 | dict.erase(erase); |
34 | } |
35 | |
36 | // erase via iterators |
37 | auto begin = dict.begin(); |
38 | for (const auto i : c10::irange(20)) { |
39 | (void)i; // Suppress unused variable warning |
40 | begin++; |
41 | } |
42 | |
43 | auto end = begin; |
44 | for (const auto i : c10::irange(20)) { |
45 | (void)i; // Suppress unused variable warning |
46 | erase_set.insert(end->first); |
47 | end++; |
48 | } |
49 | dict.erase(begin, end); |
50 | |
51 | std::vector<int64_t> order; |
52 | for (const auto i : c10::irange(100)) { |
53 | if (!erase_set.count(i)) { |
54 | order.push_back(i); |
55 | } |
56 | } |
57 | |
58 | entry_i = 0; |
59 | for (auto entry : dict) { |
60 | TORCH_INTERNAL_ASSERT(order[entry_i] == entry.first); |
61 | TORCH_INTERNAL_ASSERT(dict[order[entry_i]] == entry.second); |
62 | TORCH_INTERNAL_ASSERT(entry.second == order[entry_i] + 1); |
63 | entry_i++; |
64 | } |
65 | TORCH_INTERNAL_ASSERT(dict.size() == order.size()); |
66 | return dict; |
67 | } |
68 | |
69 | TEST(OrderedPreservingDictTest, InsertAndDeleteBasic) { |
70 | dict_int_int dict; |
71 | test_dict(dict); |
72 | dict.clear(); |
73 | test_dict(dict); |
74 | } |
75 | |
76 | TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) { |
77 | dict_int_int dict; |
78 | dict[0] = 1; |
79 | dict[1] = 2; |
80 | |
81 | TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); |
82 | dict[0] = 1; |
83 | TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); |
84 | dict[0] = 2; |
85 | TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); |
86 | |
87 | dict.erase(0); |
88 | TORCH_INTERNAL_ASSERT(dict.begin()->first == 1); |
89 | } |
90 | |
91 | TEST(OrderedPreservingDictTest, testRefType) { |
92 | std::shared_ptr<int64_t> t; |
93 | using dict_references = ska_ordered:: |
94 | order_preserving_flat_hash_map<int64_t, std::shared_ptr<int64_t>>; |
95 | |
96 | dict_references dict; |
97 | |
98 | auto ptr = std::make_shared<int64_t>(1); |
99 | dict[1] = ptr; |
100 | TORCH_INTERNAL_ASSERT(ptr.use_count() == 2); |
101 | dict.erase(1); |
102 | TORCH_INTERNAL_ASSERT(ptr.use_count() == 1); |
103 | |
104 | dict[2] = ptr; |
105 | dict.clear(); |
106 | TORCH_INTERNAL_ASSERT(ptr.use_count() == 1); |
107 | } |
108 | |
109 | TEST(OrderedPreservingDictTest, DictCollisions) { |
110 | struct BadHash { |
111 | size_t operator()(const int64_t input) { |
112 | return input % 2; |
113 | }; |
114 | }; |
115 | |
116 | using bad_hash_dict = |
117 | ska_ordered::order_preserving_flat_hash_map<int64_t, int64_t, BadHash>; |
118 | |
119 | for (auto init_dict_size : {27, 34, 41}) { |
120 | bad_hash_dict dict; |
121 | for (const auto i : c10::irange(init_dict_size)) { |
122 | dict[i] = i + 1; |
123 | } |
124 | |
125 | int64_t i = 0; |
126 | for (const auto& entry : dict) { |
127 | TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1); |
128 | ++i; |
129 | } |
130 | |
131 | // erase a few entries; |
132 | std::unordered_set<int64_t> erase_set = {0, 2, 9}; |
133 | for (auto erase : erase_set) { |
134 | dict.erase(erase); |
135 | } |
136 | |
137 | // erase a few entries via iterator |
138 | auto begin = dict.begin(); |
139 | for (const auto j : c10::irange(10)) { |
140 | (void)j; // Suppress unused variable warning |
141 | begin++; |
142 | } |
143 | auto end = begin; |
144 | for (const auto j : c10::irange(7)) { |
145 | (void)j; // Suppress unused variable warning |
146 | erase_set.insert(end->first); |
147 | end++; |
148 | } |
149 | dict.erase(begin, end); |
150 | |
151 | std::vector<int64_t> order; |
152 | for (const auto j : c10::irange(init_dict_size)) { |
153 | if (!erase_set.count(j)) { |
154 | order.push_back(j); |
155 | } |
156 | } |
157 | |
158 | i = 0; |
159 | for (auto entry : dict) { |
160 | TORCH_INTERNAL_ASSERT(dict[entry.first] == entry.second); |
161 | TORCH_INTERNAL_ASSERT(dict[entry.first] == order[i] + 1); |
162 | TORCH_INTERNAL_ASSERT(order[i] == entry.first); |
163 | i += 1; |
164 | } |
165 | TORCH_INTERNAL_ASSERT(dict.size() == order.size()); |
166 | } |
167 | } |
168 | |
169 | // Tests taken from |
170 | // https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp |
171 | |
172 | TEST(OrderedPreservingDictTest, test_range_insert) { |
173 | // insert x values in vector, range insert x-15 values from vector to map, |
174 | // check values |
175 | const int nb_values = 1000; |
176 | std::vector<std::pair<int, int>> values; |
177 | for (const auto i : c10::irange(nb_values)) { |
178 | // NOLINTNEXTLINE(modernize-use-emplace,performance-inefficient-vector-operation) |
179 | values.push_back(std::make_pair(i, i + 1)); |
180 | } |
181 | |
182 | dict_int_int map = {{-1, 0}, {-2, 0}}; |
183 | map.insert(values.begin() + 10, values.end() - 5); |
184 | |
185 | TORCH_INTERNAL_ASSERT(map.size(), 987); |
186 | |
187 | ASSERT_EQUAL_PRIM(map.at(-1), 0); |
188 | |
189 | ASSERT_EQUAL_PRIM(map.at(-2), 0); |
190 | |
191 | for (int i = 10, j = 2; i < nb_values - 5; i++, j++) { |
192 | ASSERT_EQUAL_PRIM(map.at(i), i + 1); |
193 | } |
194 | } |
195 | |
196 | TEST(OrderedPreservingDictTest, test_range_erase_all) { |
197 | // insert x values, delete all |
198 | const std::size_t nb_values = 1000; |
199 | dict_int_int map; |
200 | for (const auto i : c10::irange(nb_values)) { |
201 | map[i] = i + 1; |
202 | } |
203 | auto it = map.erase(map.begin(), map.end()); |
204 | ASSERT_TRUE(it == map.end()); |
205 | ASSERT_TRUE(map.empty()); |
206 | } |
207 | |
208 | TEST(OrderedPreservingDictTest, test_range_erase) { |
209 | // insert x values, delete all with iterators except 10 first and 780 last |
210 | // values |
211 | using HMap = |
212 | ska_ordered::order_preserving_flat_hash_map<std::string, std::int64_t>; |
213 | |
214 | const int64_t nb_values = 1000; |
215 | HMap map; |
216 | for (const auto i : c10::irange(nb_values)) { |
217 | map[c10::guts::to_string(i)] = i; |
218 | auto begin = map.begin(); |
219 | for (int64_t j = 0; j <= i; ++j, begin++) { |
220 | TORCH_INTERNAL_ASSERT(begin->second == j); |
221 | } |
222 | } |
223 | |
224 | auto it_first = std::next(map.begin(), 10); |
225 | auto it_last = std::next(map.begin(), 220); |
226 | |
227 | auto it = map.erase(it_first, it_last); |
228 | ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780); |
229 | ASSERT_EQUAL_PRIM(map.size(), 790); |
230 | ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790); |
231 | |
232 | for (auto& val : map) { |
233 | ASSERT_EQUAL_PRIM(map.count(val.first), 1); |
234 | } |
235 | |
236 | // Check order |
237 | it = map.begin(); |
238 | for (std::size_t i = 0; i < nb_values; i++) { |
239 | if (i >= 10 && i < 220) { |
240 | continue; |
241 | } |
242 | auto exp_it = |
243 | std::pair<std::string, std::int64_t>(c10::guts::to_string(i), i); |
244 | TORCH_INTERNAL_ASSERT(*it == exp_it); |
245 | ++it; |
246 | } |
247 | } |
248 | |
249 | TEST(OrderedPreservingDictTest, test_move_constructor_empty) { |
250 | ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0); |
251 | ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move( |
252 | std::move(map)); |
253 | |
254 | // NOLINTNEXTLINE(bugprone-use-after-move) |
255 | TORCH_INTERNAL_ASSERT(map.empty()); |
256 | TORCH_INTERNAL_ASSERT(map_move.empty()); |
257 | |
258 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) |
259 | TORCH_INTERNAL_ASSERT(map.find("" ) == map.end()); |
260 | TORCH_INTERNAL_ASSERT(map_move.find("" ) == map_move.end()); |
261 | } |
262 | |
263 | TEST(OrderedPreservingDictTest, test_move_operator_empty) { |
264 | ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map(0); |
265 | ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> map_move; |
266 | map_move = (std::move(map)); |
267 | |
268 | // NOLINTNEXTLINE(bugprone-use-after-move) |
269 | TORCH_INTERNAL_ASSERT(map.empty()); |
270 | TORCH_INTERNAL_ASSERT(map_move.empty()); |
271 | |
272 | // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move) |
273 | TORCH_INTERNAL_ASSERT(map.find("" ) == map.end()); |
274 | TORCH_INTERNAL_ASSERT(map_move.find("" ) == map_move.end()); |
275 | } |
276 | |
277 | TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) { |
278 | using HMap = |
279 | ska_ordered::order_preserving_flat_hash_map<std::string, std::string>; |
280 | |
281 | HMap map = {{"Key1" , "Value1" }, {"Key2" , "Value2" }, {"Key3" , "Value3" }}; |
282 | HMap map_move(std::move(map)); |
283 | |
284 | ASSERT_EQUAL_PRIM(map_move.size(), 3); |
285 | // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) |
286 | ASSERT_EQUAL_PRIM(map.size(), 0); |
287 | |
288 | map = {{"Key4" , "Value4" }, {"Key5" , "Value5" }}; |
289 | TORCH_INTERNAL_ASSERT( |
290 | map == (HMap({{"Key4" , "Value4" }, {"Key5" , "Value5" }}))); |
291 | } |
292 | |
293 | TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) { |
294 | using HMap = |
295 | ska_ordered::order_preserving_flat_hash_map<std::string, std::string>; |
296 | |
297 | HMap map = {{"Key1" , "Value1" }, {"Key2" , "Value2" }, {"Key3" , "Value3" }}; |
298 | HMap map_move = std::move(map); |
299 | |
300 | ASSERT_EQUAL_PRIM(map_move.size(), 3); |
301 | // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) |
302 | ASSERT_EQUAL_PRIM(map.size(), 0); |
303 | |
304 | map = {{"Key4" , "Value4" }, {"Key5" , "Value5" }}; |
305 | TORCH_INTERNAL_ASSERT( |
306 | map == (HMap({{"Key4" , "Value4" }, {"Key5" , "Value5" }}))); |
307 | } |
308 | |
309 | TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) { |
310 | using HMap = |
311 | ska_ordered::order_preserving_flat_hash_map<std::string, std::string>; |
312 | |
313 | const std::size_t nb_values = 100; |
314 | HMap map; |
315 | for (const auto i : c10::irange(nb_values)) { |
316 | map[c10::guts::to_string(i)] = c10::guts::to_string(i); |
317 | } |
318 | |
319 | HMap map_copy = map; |
320 | HMap map_copy2(map); |
321 | HMap map_copy3; |
322 | map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0); |
323 | |
324 | map_copy3 = map; |
325 | |
326 | TORCH_INTERNAL_ASSERT(map == map_copy); |
327 | map.clear(); |
328 | |
329 | TORCH_INTERNAL_ASSERT(map_copy == map_copy2); |
330 | TORCH_INTERNAL_ASSERT(map_copy == map_copy3); |
331 | } |
332 | |
333 | TEST(OrderedPreservingDictTest, test_copy_constructor_empty) { |
334 | ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0); |
335 | ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(map); |
336 | |
337 | TORCH_INTERNAL_ASSERT(map.empty()); |
338 | TORCH_INTERNAL_ASSERT(map_copy.empty()); |
339 | |
340 | TORCH_INTERNAL_ASSERT(map.find("" ) == map.end()); |
341 | TORCH_INTERNAL_ASSERT(map_copy.find("" ) == map_copy.end()); |
342 | } |
343 | |
344 | TEST(OrderedPreservingDictTest, test_copy_operator_empty) { |
345 | ska_ordered::order_preserving_flat_hash_map<std::string, int> map(0); |
346 | ska_ordered::order_preserving_flat_hash_map<std::string, int> map_copy(16); |
347 | map_copy = map; |
348 | |
349 | TORCH_INTERNAL_ASSERT(map.empty()); |
350 | TORCH_INTERNAL_ASSERT(map_copy.empty()); |
351 | |
352 | TORCH_INTERNAL_ASSERT(map.find("" ) == map.end()); |
353 | TORCH_INTERNAL_ASSERT(map_copy.find("" ) == map_copy.end()); |
354 | } |
355 | |
356 | /** |
357 | * at |
358 | */ |
359 | TEST(OrderedPreservingDictTest, test_at) { |
360 | // insert x values, use at for known and unknown values. |
361 | const ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> |
362 | map = {{0, 10}, {-2, 20}}; |
363 | |
364 | ASSERT_EQUAL_PRIM(map.at(0), 10); |
365 | ASSERT_EQUAL_PRIM(map.at(-2), 20); |
366 | bool thrown = false; |
367 | try { |
368 | map.at(1); |
369 | } catch (...) { |
370 | thrown = true; |
371 | } |
372 | ASSERT_TRUE(thrown); |
373 | } |
374 | |
375 | /** |
376 | * equal_range |
377 | */ |
378 | TEST(OrderedPreservingDictTest, test_equal_range) { |
379 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = |
380 | {{0, 10}, {-2, 20}}; |
381 | |
382 | auto it_pair = map.equal_range(0); |
383 | ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1); |
384 | ASSERT_EQUAL_PRIM(it_pair.first->second, 10); |
385 | |
386 | it_pair = map.equal_range(1); |
387 | TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second); |
388 | TORCH_INTERNAL_ASSERT(it_pair.first == map.end()); |
389 | } |
390 | |
391 | /** |
392 | * operator[] |
393 | */ |
394 | TEST(OrderedPreservingDictTest, test_access_operator) { |
395 | // insert x values, use at for known and unknown values. |
396 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = |
397 | {{0, 10}, {-2, 20}}; |
398 | |
399 | ASSERT_EQUAL_PRIM(map[0], 10); |
400 | ASSERT_EQUAL_PRIM(map[-2], 20); |
401 | ASSERT_EQUAL_PRIM(map[2], std::int64_t()); |
402 | |
403 | ASSERT_EQUAL_PRIM(map.size(), 3); |
404 | } |
405 | |
406 | /** |
407 | * swap |
408 | */ |
409 | TEST(OrderedPreservingDictTest, test_swap) { |
410 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = |
411 | {{1, 10}, {8, 80}, {3, 30}}; |
412 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2 = |
413 | {{4, 40}, {5, 50}}; |
414 | |
415 | using std::swap; |
416 | swap(map, map2); |
417 | |
418 | TORCH_INTERNAL_ASSERT( |
419 | map == |
420 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
421 | {4, 40}, {5, 50}})); |
422 | TORCH_INTERNAL_ASSERT( |
423 | map2 == |
424 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
425 | {1, 10}, {8, 80}, {3, 30}})); |
426 | |
427 | map.insert({6, 60}); |
428 | map2.insert({4, 40}); |
429 | |
430 | TORCH_INTERNAL_ASSERT( |
431 | map == |
432 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
433 | {4, 40}, {5, 50}, {6, 60}})); |
434 | TORCH_INTERNAL_ASSERT( |
435 | map2 == |
436 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
437 | {1, 10}, {8, 80}, {3, 30}, {4, 40}})); |
438 | } |
439 | |
440 | TEST(OrderedPreservingDictTest, test_swap_empty) { |
441 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map = |
442 | {{1, 10}, {8, 80}, {3, 30}}; |
443 | ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t> map2; |
444 | |
445 | using std::swap; |
446 | swap(map, map2); |
447 | |
448 | TORCH_INTERNAL_ASSERT( |
449 | map == |
450 | (ska_ordered:: |
451 | order_preserving_flat_hash_map<std::int64_t, std::int64_t>{})); |
452 | TORCH_INTERNAL_ASSERT( |
453 | map2 == |
454 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
455 | {1, 10}, {8, 80}, {3, 30}})); |
456 | |
457 | map.insert({6, 60}); |
458 | map2.insert({4, 40}); |
459 | |
460 | TORCH_INTERNAL_ASSERT( |
461 | map == |
462 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
463 | {6, 60}})); |
464 | TORCH_INTERNAL_ASSERT( |
465 | map2 == |
466 | (ska_ordered::order_preserving_flat_hash_map<std::int64_t, std::int64_t>{ |
467 | {1, 10}, {8, 80}, {3, 30}, {4, 40}})); |
468 | } |
469 | |
470 | } // namespace |
471 | |