1/**
2 * This file is adapted from PyTorch/XLA
3 * https://github.com/pytorch/xla/blob/e0e5f937a0ba8d904f9608137dc8c51ba439df2d/third_party/xla_client/util.h
4 */
5#include <iomanip>
6#include <sstream>
7
8#include <torch/csrc/lazy/core/hash.h>
9
10namespace torch {
11namespace lazy {
12namespace {
13
14hash_t LoadHash(const uint8_t** data, const uint8_t* top) {
15 std::ptrdiff_t size = top - (*data);
16 if (size >= (int)sizeof(hash_t)) {
17 hash_t v;
18 std::memcpy(&v, *data, sizeof(v));
19 *data += sizeof(hash_t);
20 return v;
21 }
22 union {
23 hash_t h;
24 std::array<uint8_t, sizeof(hash_t)> b;
25#ifdef _MSC_VER
26 // MSVC (or some versions we use) doesn't support C99 union field init
27 // but it initializes the first member of the union.
28 } uval = {hash_t(0)};
29#else
30 } uval = {.h = hash_t(0)};
31#endif
32 // use memcpy for compatibility with platforms not supporting unaligned access
33 // note: compiled as single `movl` instr on x64.
34 std::memcpy(uval.b.data(), *data, size);
35 *data += size;
36 return uval.h;
37}
38
39} // namespace
40
41hash_t HashBlock(const void* data, size_t n, const hash_t& seed) {
42 const hash_t m(static_cast<uint64_t>(0xc6a4a7935bd1e995));
43 const int r = 47;
44
45 const uint8_t* u8_data = reinterpret_cast<const uint8_t*>(data);
46 const uint8_t* top = u8_data + n;
47 hash_t h(seed ^ ((uint64_t)n * m));
48 while (u8_data < top) {
49 hash_t k = LoadHash(&u8_data, top);
50 k *= m;
51 k ^= k >> r;
52 k *= m;
53
54 h ^= k;
55 h *= m;
56 }
57 h ^= h >> r;
58 h *= m;
59 h ^= h >> r;
60 return h;
61}
62
63hash_t DataHash(const void* data, size_t size) {
64 return HashBlock(
65 data, size, hash_t(static_cast<uint64_t>(0xc2b2ae3d27d4eb4f)));
66}
67
68size_t StdDataHash(const void* data, size_t size) {
69 return HashReduce(DataHash(data, size));
70}
71
// Combines two integer hash values into a single size_t.
//
// `b` is scaled by a fixed odd multiplier and offset by a fixed constant plus
// two shifted copies of `a`; the result is then XOR-ed with `a`. The operation
// is not symmetric in its arguments.
size_t StdHashCombine(uintmax_t a, uintmax_t b) {
  const uintmax_t scaled_b = b * 0x27d4eb2f165667c5;
  const uintmax_t mixed = scaled_b + 0x9e3779b97f4a7c15 + (a << 6) + (a >> 2);
  return a ^ mixed;
}
76
77hash_t HashCombine(const hash_t& a, const hash_t& b) {
78 static const hash_t kb(101, 0x27d4eb2f165667c5);
79 return hash_t(
80 a ^ (b * kb + (uint64_t)0x9e3779b97f4a7c15 + (a << 6) + (a >> 2)));
81}
82
83size_t HashReduce(const hash_t& a) {
84 return StdHashCombine(c10::Uint128Low64(a), c10::Uint128High64(a));
85}
86
87std::string HashToString(const hash_t& a) {
88 std::stringstream ss;
89 ss << std::hex << c10::Uint128High64(a) << std::setfill('0') << std::setw(16)
90 << Uint128Low64(a);
91 return ss.str();
92}
93
94hash_t Hash(const std::vector<bool>& values) {
95 // We can't assume a DataHash size/dataptr approach here bc
96 // vector<bool> can be optimized as vector<bit> and storage details
97 // are decoupled from actual size of 'bool' type
98 hash_t h(static_cast<uint64_t>(0xad2ed1983bbf2e28));
99 static const hash_t h_true(static_cast<uint64_t>(0x74f6b5198daa2b2));
100 static const hash_t h_false(static_cast<uint64_t>(0xe39f30789cab5382));
101 for (const auto& b : values) {
102 if (b) {
103 h = HashCombine(h, h_true);
104 } else {
105 h = HashCombine(h, h_false);
106 }
107 }
108 return h;
109}
110
111} // namespace lazy
112} // namespace torch
113