#include <gtest/gtest.h>
#include <cstdint>
#include <string>
#include <vector>

#include <c10/util/int128.h>
#include <torch/csrc/lazy/core/hash.h>
6 | |
7 | namespace torch { |
8 | namespace lazy { |
9 | |
10 | template <typename T> |
11 | void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) { |
12 | // repeatable |
13 | EXPECT_EQ(Hash(example_a), Hash(example_a)); |
14 | EXPECT_EQ(MHash(example_a), MHash(example_a)); |
15 | EXPECT_EQ(MHash(example_a, example_a), MHash(example_a, example_a)); |
16 | |
17 | // sensitive |
18 | EXPECT_NE(Hash(example_a), Hash(example_b)); |
19 | EXPECT_NE(MHash(example_a), MHash(example_b)); |
20 | EXPECT_NE(MHash(example_a, example_a), MHash(example_a, example_b)); |
21 | } |
22 | |
23 | TEST(HashTest, Scalar) { |
24 | c10::Scalar a(0); |
25 | c10::Scalar b(0); |
26 | |
27 | // simulate some garbage in the unused bits of the |
28 | // the tagged union that is c10::Scalar, which is bigger |
29 | // than the size of the int64_t we're currently using it with |
30 | *((uint8_t*)&b) = 1; |
31 | // actual 'value' of the Scalar as a 64 bit int shouldn't have changed |
32 | EXPECT_EQ(a.toLong(), b.toLong()); |
33 | // and hash should ignore this garbage |
34 | EXPECT_EQ(Hash(a), Hash(b)); |
35 | EXPECT_EQ(MHash(a), MHash(b)); |
36 | EXPECT_EQ(MHash(a, a), MHash(a, b)); |
37 | } |
38 | |
39 | TEST(HashTest, Sanity) { |
40 | // String |
41 | test_hash_repeatable_sensitive( |
42 | std::string( |
43 | "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus." ), |
44 | std::string( |
45 | "Lorem Jpsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus." )); |
46 | |
47 | // Number types |
48 | test_hash_repeatable_sensitive(true, false); |
49 | test_hash_repeatable_sensitive((int8_t)0xfa, (int8_t)0xfb); |
50 | test_hash_repeatable_sensitive((int16_t)0xface, (int16_t)0xfade); |
51 | test_hash_repeatable_sensitive((int32_t)0xfaceb000, (int32_t)0xfadeb000); |
52 | test_hash_repeatable_sensitive((int64_t)0x1faceb000, (int64_t)0x1fadeb000); |
53 | test_hash_repeatable_sensitive((uint8_t)0xfa, (uint8_t)0xfb); |
54 | test_hash_repeatable_sensitive((uint16_t)0xface, (uint16_t)0xfade); |
55 | test_hash_repeatable_sensitive((uint32_t)0xfaceb000, (uint32_t)0xfadeb000); |
56 | test_hash_repeatable_sensitive((uint64_t)0x1faceb000, (uint64_t)0x1fadeb000); |
57 | |
58 | // c10 types |
59 | test_hash_repeatable_sensitive(c10::ScalarType::Bool, c10::ScalarType::Byte); |
60 | test_hash_repeatable_sensitive(c10::Scalar(1.334), c10::Scalar(1.335)); |
61 | test_hash_repeatable_sensitive(c10::Scalar(true), c10::Scalar(false)); |
62 | test_hash_repeatable_sensitive(c10::Scalar(12345), c10::Scalar(12354)); |
63 | |
64 | // c10::optional |
65 | test_hash_repeatable_sensitive( |
66 | c10::optional<std::string>("I have value!" ), |
67 | c10::optional<std::string>(c10::nullopt)); |
68 | |
69 | // Containers |
70 | auto a = std::vector<int32_t>({0, 1, 1, 2, 3, 5, 8}); |
71 | auto b = std::vector<int32_t>({1, 1, 2, 3, 5, 8, 12}); |
72 | test_hash_repeatable_sensitive(a, b); |
73 | test_hash_repeatable_sensitive( |
74 | c10::ArrayRef<int32_t>(a), c10::ArrayRef<int32_t>(b)); |
75 | |
76 | // vector<bool> is a special case bc it is implemented as vector<bit> |
77 | auto bool_a = std::vector<bool>({true, false, false, true}); |
78 | auto bool_b = std::vector<bool>({true, true, false, true}); |
79 | test_hash_repeatable_sensitive(bool_a, bool_b); |
80 | } |
81 | |
82 | } // namespace lazy |
83 | } // namespace torch |
84 | |