1#include <gtest/gtest.h>
2#include <string>
3
4#include <c10/util/int128.h>
5#include <torch/csrc/lazy/core/hash.h>
6
7namespace torch {
8namespace lazy {
9
// Shared check for the tests in this file. Given two values of the same type
// that the caller expects to hash differently, verifies that Hash() and
// MHash() (single- and two-argument forms):
//   * return equal results when called twice with the same input, and
//   * return different results when any input differs.
//
// example_a, example_b: two distinct values of type T.
template <typename T>
void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) {
  // Repeatable: hashing the same input twice must agree.
  const auto hash_first = Hash(example_a);
  const auto hash_again = Hash(example_a);
  EXPECT_EQ(hash_first, hash_again);

  const auto mhash_first = MHash(example_a);
  const auto mhash_again = MHash(example_a);
  EXPECT_EQ(mhash_first, mhash_again);

  const auto mhash_pair_first = MHash(example_a, example_a);
  const auto mhash_pair_again = MHash(example_a, example_a);
  EXPECT_EQ(mhash_pair_first, mhash_pair_again);

  // Sensitive: changing the input must change the hash.
  const auto hash_of_a = Hash(example_a);
  const auto hash_of_b = Hash(example_b);
  EXPECT_NE(hash_of_a, hash_of_b);

  const auto mhash_of_a = MHash(example_a);
  const auto mhash_of_b = MHash(example_b);
  EXPECT_NE(mhash_of_a, mhash_of_b);

  // Only the second argument differs here, so this also checks that the
  // two-argument form takes its later argument into account.
  const auto mhash_of_aa = MHash(example_a, example_a);
  const auto mhash_of_ab = MHash(example_a, example_b);
  EXPECT_NE(mhash_of_aa, mhash_of_ab);
}
22
23TEST(HashTest, Scalar) {
24 c10::Scalar a(0);
25 c10::Scalar b(0);
26
27 // simulate some garbage in the unused bits of the
28 // the tagged union that is c10::Scalar, which is bigger
29 // than the size of the int64_t we're currently using it with
30 *((uint8_t*)&b) = 1;
31 // actual 'value' of the Scalar as a 64 bit int shouldn't have changed
32 EXPECT_EQ(a.toLong(), b.toLong());
33 // and hash should ignore this garbage
34 EXPECT_EQ(Hash(a), Hash(b));
35 EXPECT_EQ(MHash(a), MHash(b));
36 EXPECT_EQ(MHash(a, a), MHash(a, b));
37}
38
39TEST(HashTest, Sanity) {
40 // String
41 test_hash_repeatable_sensitive(
42 std::string(
43 "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."),
44 std::string(
45 "Lorem Jpsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."));
46
47 // Number types
48 test_hash_repeatable_sensitive(true, false);
49 test_hash_repeatable_sensitive((int8_t)0xfa, (int8_t)0xfb);
50 test_hash_repeatable_sensitive((int16_t)0xface, (int16_t)0xfade);
51 test_hash_repeatable_sensitive((int32_t)0xfaceb000, (int32_t)0xfadeb000);
52 test_hash_repeatable_sensitive((int64_t)0x1faceb000, (int64_t)0x1fadeb000);
53 test_hash_repeatable_sensitive((uint8_t)0xfa, (uint8_t)0xfb);
54 test_hash_repeatable_sensitive((uint16_t)0xface, (uint16_t)0xfade);
55 test_hash_repeatable_sensitive((uint32_t)0xfaceb000, (uint32_t)0xfadeb000);
56 test_hash_repeatable_sensitive((uint64_t)0x1faceb000, (uint64_t)0x1fadeb000);
57
58 // c10 types
59 test_hash_repeatable_sensitive(c10::ScalarType::Bool, c10::ScalarType::Byte);
60 test_hash_repeatable_sensitive(c10::Scalar(1.334), c10::Scalar(1.335));
61 test_hash_repeatable_sensitive(c10::Scalar(true), c10::Scalar(false));
62 test_hash_repeatable_sensitive(c10::Scalar(12345), c10::Scalar(12354));
63
64 // c10::optional
65 test_hash_repeatable_sensitive(
66 c10::optional<std::string>("I have value!"),
67 c10::optional<std::string>(c10::nullopt));
68
69 // Containers
70 auto a = std::vector<int32_t>({0, 1, 1, 2, 3, 5, 8});
71 auto b = std::vector<int32_t>({1, 1, 2, 3, 5, 8, 12});
72 test_hash_repeatable_sensitive(a, b);
73 test_hash_repeatable_sensitive(
74 c10::ArrayRef<int32_t>(a), c10::ArrayRef<int32_t>(b));
75
76 // vector<bool> is a special case bc it is implemented as vector<bit>
77 auto bool_a = std::vector<bool>({true, false, false, true});
78 auto bool_b = std::vector<bool>({true, true, false, true});
79 test_hash_repeatable_sensitive(bool_a, bool_b);
80}
81
82} // namespace lazy
83} // namespace torch
84