1 | #pragma once |
2 | |
3 | #include <functional> |
4 | #include <iomanip> |
5 | #include <sstream> |
6 | #include <vector> |
7 | |
8 | #include <c10/util/ArrayRef.h> |
9 | #include <c10/util/complex.h> |
10 | |
11 | namespace c10 { |
12 | |
// NOTE: hash_combine and SHA1 hashing are based on the implementation from Boost
14 | // |
15 | // Boost Software License - Version 1.0 - August 17th, 2003 |
16 | // |
17 | // Permission is hereby granted, free of charge, to any person or organization |
18 | // obtaining a copy of the software and accompanying documentation covered by |
19 | // this license (the "Software") to use, reproduce, display, distribute, |
20 | // execute, and transmit the Software, and to prepare derivative works of the |
21 | // Software, and to permit third-parties to whom the Software is furnished to |
22 | // do so, all subject to the following: |
23 | // |
24 | // The copyright notices in the Software and this entire statement, including |
25 | // the above license grant, this restriction and the following disclaimer, |
26 | // must be included in all copies of the Software, in whole or in part, and |
27 | // all derivative works of the Software, unless such copies or derivative |
28 | // works are solely in the form of machine-executable object code generated by |
29 | // a source language processor. |
30 | // |
31 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
32 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
33 | // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT |
34 | // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE |
35 | // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, |
36 | // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER |
37 | // DEALINGS IN THE SOFTWARE. |
38 | |
// Folds `value` into the running hash `seed` using Boost's hash_combine recipe:
// `value` is offset by the constant 0x9e3779b9 plus two shifted copies of
// `seed`, and the sum is XOR-ed back into `seed`. The dependence on `seed`
// makes the result sensitive to the order in which values are combined.
inline size_t hash_combine(size_t seed, size_t value) {
  const size_t mixed = value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
  return seed ^ mixed;
}
42 | |
43 | // Creates the SHA1 hash of a string. A 160-bit hash. |
44 | // Based on the implementation in Boost (see notice above). |
45 | // Note that SHA1 hashes are no longer considered cryptographically |
46 | // secure, but are the standard hash for generating unique ids. |
47 | // Usage: |
48 | // // Let 'code' be a std::string |
49 | // c10::sha1 sha1_hash{code}; |
50 | // const auto hash_code = sha1_hash.str(); |
51 | // TODO: Compare vs OpenSSL and/or CryptoPP implementations |
52 | struct sha1 { |
53 | typedef unsigned int(digest_type)[5]; |
54 | |
55 | sha1(const std::string& s = "" ) { |
56 | if (!s.empty()) { |
57 | reset(); |
58 | process_bytes(s.c_str(), s.size()); |
59 | } |
60 | } |
61 | |
62 | void reset() { |
63 | h_[0] = 0x67452301; |
64 | h_[1] = 0xEFCDAB89; |
65 | h_[2] = 0x98BADCFE; |
66 | h_[3] = 0x10325476; |
67 | h_[4] = 0xC3D2E1F0; |
68 | |
69 | block_byte_index_ = 0; |
70 | bit_count_low = 0; |
71 | bit_count_high = 0; |
72 | } |
73 | |
74 | std::string str() { |
75 | unsigned int digest[5]; |
76 | get_digest(digest); |
77 | |
78 | std::ostringstream buf; |
79 | for (unsigned int i : digest) { |
80 | buf << std::hex << std::setfill('0') << std::setw(8) << i; |
81 | } |
82 | |
83 | return buf.str(); |
84 | } |
85 | |
86 | private: |
87 | unsigned int left_rotate(unsigned int x, std::size_t n) { |
88 | return (x << n) ^ (x >> (32 - n)); |
89 | } |
90 | |
91 | void process_block_impl() { |
92 | unsigned int w[80]; |
93 | |
94 | for (std::size_t i = 0; i < 16; ++i) { |
95 | w[i] = (block_[i * 4 + 0] << 24); |
96 | w[i] |= (block_[i * 4 + 1] << 16); |
97 | w[i] |= (block_[i * 4 + 2] << 8); |
98 | w[i] |= (block_[i * 4 + 3]); |
99 | } |
100 | |
101 | for (std::size_t i = 16; i < 80; ++i) { |
102 | w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1); |
103 | } |
104 | |
105 | unsigned int a = h_[0]; |
106 | unsigned int b = h_[1]; |
107 | unsigned int c = h_[2]; |
108 | unsigned int d = h_[3]; |
109 | unsigned int e = h_[4]; |
110 | |
111 | for (std::size_t i = 0; i < 80; ++i) { |
112 | unsigned int f; |
113 | unsigned int k; |
114 | |
115 | if (i < 20) { |
116 | f = (b & c) | (~b & d); |
117 | k = 0x5A827999; |
118 | } else if (i < 40) { |
119 | f = b ^ c ^ d; |
120 | k = 0x6ED9EBA1; |
121 | } else if (i < 60) { |
122 | f = (b & c) | (b & d) | (c & d); |
123 | k = 0x8F1BBCDC; |
124 | } else { |
125 | f = b ^ c ^ d; |
126 | k = 0xCA62C1D6; |
127 | } |
128 | |
129 | unsigned temp = left_rotate(a, 5) + f + e + k + w[i]; |
130 | e = d; |
131 | d = c; |
132 | c = left_rotate(b, 30); |
133 | b = a; |
134 | a = temp; |
135 | } |
136 | |
137 | h_[0] += a; |
138 | h_[1] += b; |
139 | h_[2] += c; |
140 | h_[3] += d; |
141 | h_[4] += e; |
142 | } |
143 | |
144 | void process_byte_impl(unsigned char byte) { |
145 | block_[block_byte_index_++] = byte; |
146 | |
147 | if (block_byte_index_ == 64) { |
148 | block_byte_index_ = 0; |
149 | process_block_impl(); |
150 | } |
151 | } |
152 | |
153 | void process_byte(unsigned char byte) { |
154 | process_byte_impl(byte); |
155 | |
156 | // size_t max value = 0xFFFFFFFF |
157 | // if (bit_count_low + 8 >= 0x100000000) { // would overflow |
158 | // if (bit_count_low >= 0x100000000-8) { |
159 | if (bit_count_low < 0xFFFFFFF8) { |
160 | bit_count_low += 8; |
161 | } else { |
162 | bit_count_low = 0; |
163 | |
164 | if (bit_count_high <= 0xFFFFFFFE) { |
165 | ++bit_count_high; |
166 | } else { |
167 | TORCH_CHECK(false, "sha1 too many bytes" ); |
168 | } |
169 | } |
170 | } |
171 | |
172 | void process_block(void const* bytes_begin, void const* bytes_end) { |
173 | unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin); |
174 | unsigned char const* end = static_cast<unsigned char const*>(bytes_end); |
175 | for (; begin != end; ++begin) { |
176 | process_byte(*begin); |
177 | } |
178 | } |
179 | |
180 | void process_bytes(void const* buffer, std::size_t byte_count) { |
181 | unsigned char const* b = static_cast<unsigned char const*>(buffer); |
182 | process_block(b, b + byte_count); |
183 | } |
184 | |
185 | void get_digest(digest_type& digest) { |
186 | // append the bit '1' to the message |
187 | process_byte_impl(0x80); |
188 | |
189 | // append k bits '0', where k is the minimum number >= 0 |
190 | // such that the resulting message length is congruent to 56 (mod 64) |
191 | // check if there is enough space for padding and bit_count |
192 | if (block_byte_index_ > 56) { |
193 | // finish this block |
194 | while (block_byte_index_ != 0) { |
195 | process_byte_impl(0); |
196 | } |
197 | |
198 | // one more block |
199 | while (block_byte_index_ < 56) { |
200 | process_byte_impl(0); |
201 | } |
202 | } else { |
203 | while (block_byte_index_ < 56) { |
204 | process_byte_impl(0); |
205 | } |
206 | } |
207 | |
208 | // append length of message (before pre-processing) |
209 | // as a 64-bit big-endian integer |
210 | process_byte_impl( |
211 | static_cast<unsigned char>((bit_count_high >> 24) & 0xFF)); |
212 | process_byte_impl( |
213 | static_cast<unsigned char>((bit_count_high >> 16) & 0xFF)); |
214 | process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF)); |
215 | process_byte_impl(static_cast<unsigned char>((bit_count_high)&0xFF)); |
216 | process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF)); |
217 | process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF)); |
218 | process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF)); |
219 | process_byte_impl(static_cast<unsigned char>((bit_count_low)&0xFF)); |
220 | |
221 | // get final digest |
222 | digest[0] = h_[0]; |
223 | digest[1] = h_[1]; |
224 | digest[2] = h_[2]; |
225 | digest[3] = h_[3]; |
226 | digest[4] = h_[4]; |
227 | } |
228 | |
229 | unsigned int h_[5]; |
230 | unsigned char block_[64]; |
231 | std::size_t block_byte_index_; |
232 | std::size_t bit_count_low; |
233 | std::size_t bit_count_high; |
234 | }; |
235 | |
236 | //////////////////////////////////////////////////////////////////////////////// |
237 | // c10::hash implementation |
238 | //////////////////////////////////////////////////////////////////////////////// |
239 | |
namespace _hash_detail {

// Forward declaration. The definition comes later in this file, after the
// c10::hash specializations. This lets those specializations hash their
// elements without naming the element type (T is deduced from the argument).
template <typename T>
size_t simple_get_hash(const T& o);

// Yields V when T is not an enumeration; substitution fails otherwise.
template <typename T, typename V>
using type_if_not_enum =
    typename std::enable_if<!std::is_enum<T>::value, V>::type;

// dispatch_hash is an overload set resolved by SFINAE. Each overload is meant
// to be the only viable one for its category of T:
//   (1) non-enum types for which std::hash<T>()(value) is well-formed;
//   (2) enum types, hashed through std::hash of their underlying integer type;
//   (3) types providing a static T::hash(const T&).
// C++14 added std::hash for enums, and some compilers provide it even without
// C++14 flags. That is why (1) explicitly excludes enums, leaving them to (2).
// A type satisfying both (1) and (3) would make the call ambiguous.

// (1) Defer to std::hash.
template <typename T>
auto dispatch_hash(const T& value)
    -> decltype(std::hash<T>()(value), type_if_not_enum<T, size_t>()) {
  std::hash<T> hasher{};
  return hasher(value);
}

// (2) Enums: cast to the underlying integer type and hash that.
template <typename T>
auto dispatch_hash(const T& value) ->
    typename std::enable_if<std::is_enum<T>::value, size_t>::type {
  using Underlying = typename std::underlying_type<T>::type;
  return std::hash<Underlying>()(static_cast<Underlying>(value));
}

// (3) Fall back to the type's own static hash function.
template <typename T>
auto dispatch_hash(const T& value) -> decltype(T::hash(value), size_t()) {
  return T::hash(value);
}

} // namespace _hash_detail
274 | |
275 | // Hasher struct |
276 | template <typename T> |
277 | struct hash { |
278 | size_t operator()(const T& o) const { |
279 | return _hash_detail::dispatch_hash(o); |
280 | }; |
281 | }; |
282 | |
283 | // Specialization for std::tuple |
284 | template <typename... Types> |
285 | struct hash<std::tuple<Types...>> { |
286 | template <size_t idx, typename... Ts> |
287 | struct tuple_hash { |
288 | size_t operator()(const std::tuple<Ts...>& t) const { |
289 | return hash_combine( |
290 | _hash_detail::simple_get_hash(std::get<idx>(t)), |
291 | tuple_hash<idx - 1, Ts...>()(t)); |
292 | } |
293 | }; |
294 | |
295 | template <typename... Ts> |
296 | struct tuple_hash<0, Ts...> { |
297 | size_t operator()(const std::tuple<Ts...>& t) const { |
298 | return _hash_detail::simple_get_hash(std::get<0>(t)); |
299 | } |
300 | }; |
301 | |
302 | size_t operator()(const std::tuple<Types...>& t) const { |
303 | return tuple_hash<sizeof...(Types) - 1, Types...>()(t); |
304 | } |
305 | }; |
306 | |
307 | template <typename T1, typename T2> |
308 | struct hash<std::pair<T1, T2>> { |
309 | size_t operator()(const std::pair<T1, T2>& pair) const { |
310 | std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second); |
311 | return _hash_detail::simple_get_hash(tuple); |
312 | } |
313 | }; |
314 | |
315 | template <typename T> |
316 | struct hash<c10::ArrayRef<T>> { |
317 | size_t operator()(c10::ArrayRef<T> v) const { |
318 | size_t seed = 0; |
319 | for (const auto& elem : v) { |
320 | seed = hash_combine(seed, _hash_detail::simple_get_hash(elem)); |
321 | } |
322 | return seed; |
323 | } |
324 | }; |
325 | |
326 | // Specialization for std::vector |
327 | template <typename T> |
328 | struct hash<std::vector<T>> { |
329 | size_t operator()(const std::vector<T>& v) const { |
330 | return hash<c10::ArrayRef<T>>()(v); |
331 | } |
332 | }; |
333 | |
334 | namespace _hash_detail { |
335 | |
336 | template <typename T> |
337 | size_t simple_get_hash(const T& o) { |
338 | return c10::hash<T>()(o); |
339 | } |
340 | |
341 | } // namespace _hash_detail |
342 | |
343 | // Use this function to actually hash multiple things in one line. |
344 | // Dispatches to c10::hash, so it can hash containers. |
345 | // Example: |
346 | // |
347 | // static size_t hash(const MyStruct& s) { |
348 | // return get_hash(s.member1, s.member2, s.member3); |
349 | // } |
350 | template <typename... Types> |
351 | size_t get_hash(const Types&... args) { |
352 | return c10::hash<decltype(std::tie(args...))>()(std::tie(args...)); |
353 | } |
354 | |
355 | // Specialization for c10::complex |
356 | template <typename T> |
357 | struct hash<c10::complex<T>> { |
358 | size_t operator()(const c10::complex<T>& c) const { |
359 | return get_hash(c.real(), c.imag()); |
360 | } |
361 | }; |
362 | |
363 | } // namespace c10 |
364 | |