1#pragma once
2
3#include <functional>
4#include <iomanip>
5#include <sstream>
6#include <vector>
7
8#include <c10/util/ArrayRef.h>
9#include <c10/util/complex.h>
10
11namespace c10 {
12
// NOTE: hash_combine and SHA1 hashing are based on the implementations from Boost
14//
15// Boost Software License - Version 1.0 - August 17th, 2003
16//
17// Permission is hereby granted, free of charge, to any person or organization
18// obtaining a copy of the software and accompanying documentation covered by
19// this license (the "Software") to use, reproduce, display, distribute,
20// execute, and transmit the Software, and to prepare derivative works of the
21// Software, and to permit third-parties to whom the Software is furnished to
22// do so, all subject to the following:
23//
24// The copyright notices in the Software and this entire statement, including
25// the above license grant, this restriction and the following disclaimer,
26// must be included in all copies of the Software, in whole or in part, and
27// all derivative works of the Software, unless such copies or derivative
28// works are solely in the form of machine-executable object code generated by
29// a source language processor.
30//
31// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
34// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
35// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
36// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
37// DEALINGS IN THE SOFTWARE.
38
// Mixes `value` into the running hash `seed` and returns the new seed.
// The result depends on argument order, so combining a sequence of values
// yields an order-sensitive hash.
inline size_t hash_combine(size_t seed, size_t value) {
  const size_t shifted_up = seed << 6u;
  const size_t shifted_down = seed >> 2u;
  const size_t mixed = value + 0x9e3779b9 + shifted_up + shifted_down;
  return seed ^ mixed;
}
42
// Creates the SHA1 hash of a string. A 160-bit hash, rendered by str() as 40
// lowercase hexadecimal characters.
// Based on the implementation in Boost (see notice above).
// Note that SHA1 hashes are no longer considered cryptographically
// secure; use them to derive identifiers, not for anything security related.
// Usage:
//   // Let 'code' be a std::string
//   c10::sha1 sha1_hash{code};
//   const auto hash_code = sha1_hash.str();
// TODO: Compare vs OpenSSL and/or CryptoPP implementations
struct sha1 {
  typedef unsigned int(digest_type)[5];

  // Hashes the bytes of `s`. The state is always initialized first, so an
  // empty (or defaulted) string yields the well-defined digest of the empty
  // message rather than reading uninitialized members.
  sha1(const std::string& s = "") {
    reset();
    process_bytes(s.data(), s.size());
  }

  // Restores the initial chaining values and clears the message length, as
  // if no bytes had been hashed.
  void reset() {
    h_[0] = 0x67452301;
    h_[1] = 0xEFCDAB89;
    h_[2] = 0x98BADCFE;
    h_[3] = 0x10325476;
    h_[4] = 0xC3D2E1F0;

    block_byte_index_ = 0;
    bit_count_low = 0;
    bit_count_high = 0;
  }

  // Returns the digest of the bytes hashed so far as 40 lowercase hex digits.
  // Padding is applied to a copy of the state, so this object is left
  // untouched and repeated calls return the same string.
  std::string str() const {
    sha1 finalized(*this);
    digest_type digest;
    finalized.get_digest(digest);

    std::ostringstream buf;
    buf << std::hex << std::setfill('0');
    for (unsigned int word : digest) {
      // setw is reset after every insertion, so it is applied per word.
      buf << std::setw(8) << word;
    }

    return buf.str();
  }

 private:
  // Rotates the low 32 bits of `x` left by `n`; callers pass 0 < n < 32.
  static unsigned int left_rotate(unsigned int x, std::size_t n) {
    return (x << n) ^ (x >> (32 - n));
  }

  // Runs the 80-round compression function over the 64 bytes in block_ and
  // folds the result into h_.
  void process_block_impl() {
    unsigned int w[80];

    // Load the block as sixteen big-endian 32-bit words. Bytes are widened to
    // unsigned int before shifting so no shift touches a signed sign bit.
    for (std::size_t i = 0; i < 16; ++i) {
      const unsigned char* p = block_ + i * 4;
      w[i] = (static_cast<unsigned int>(p[0]) << 24) |
          (static_cast<unsigned int>(p[1]) << 16) |
          (static_cast<unsigned int>(p[2]) << 8) |
          static_cast<unsigned int>(p[3]);
    }

    // Expand the message schedule to 80 words.
    for (std::size_t i = 16; i < 80; ++i) {
      w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
    }

    unsigned int a = h_[0];
    unsigned int b = h_[1];
    unsigned int c = h_[2];
    unsigned int d = h_[3];
    unsigned int e = h_[4];

    for (std::size_t i = 0; i < 80; ++i) {
      unsigned int f;
      unsigned int k;

      // Each group of 20 rounds uses its own boolean function and constant.
      if (i < 20) {
        f = (b & c) | (~b & d);
        k = 0x5A827999;
      } else if (i < 40) {
        f = b ^ c ^ d;
        k = 0x6ED9EBA1;
      } else if (i < 60) {
        f = (b & c) | (b & d) | (c & d);
        k = 0x8F1BBCDC;
      } else {
        f = b ^ c ^ d;
        k = 0xCA62C1D6;
      }

      const unsigned int temp = left_rotate(a, 5) + f + e + k + w[i];
      e = d;
      d = c;
      c = left_rotate(b, 30);
      b = a;
      a = temp;
    }

    h_[0] += a;
    h_[1] += b;
    h_[2] += c;
    h_[3] += d;
    h_[4] += e;
  }

  // Appends one byte to the current block (without counting it towards the
  // message length) and compresses the block once it holds 64 bytes.
  void process_byte_impl(unsigned char byte) {
    block_[block_byte_index_++] = byte;

    if (block_byte_index_ == 64) {
      block_byte_index_ = 0;
      process_block_impl();
    }
  }

  // Appends one message byte and adds 8 to the message bit length.
  void process_byte(unsigned char byte) {
    process_byte_impl(byte);

    // bit_count_low / bit_count_high are used as the two 32-bit halves of the
    // 64-bit message length, independent of how wide std::size_t is.
    // bit_count_low only ever holds multiples of 8, so adding 8 to 0xFFFFFFF8
    // is exactly the point where the low half wraps and carries.
    if (bit_count_low < 0xFFFFFFF8) {
      bit_count_low += 8;
    } else {
      bit_count_low = 0;

      if (bit_count_high <= 0xFFFFFFFE) {
        ++bit_count_high;
      } else {
        // The length no longer fits into 64 bits.
        TORCH_CHECK(false, "sha1 too many bytes");
      }
    }
  }

  // Feeds every byte in the half-open range [bytes_begin, bytes_end).
  void process_block(void const* bytes_begin, void const* bytes_end) {
    unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
    unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
    for (; begin != end; ++begin) {
      process_byte(*begin);
    }
  }

  // Feeds `byte_count` bytes starting at `buffer`.
  void process_bytes(void const* buffer, std::size_t byte_count) {
    unsigned char const* b = static_cast<unsigned char const*>(buffer);
    process_block(b, b + byte_count);
  }

  // Appends the low 32 bits of `word` to the block, most significant byte
  // first.
  void append_be32(std::size_t word) {
    process_byte_impl(static_cast<unsigned char>((word >> 24) & 0xFF));
    process_byte_impl(static_cast<unsigned char>((word >> 16) & 0xFF));
    process_byte_impl(static_cast<unsigned char>((word >> 8) & 0xFF));
    process_byte_impl(static_cast<unsigned char>(word & 0xFF));
  }

  // Applies the final padding and writes the five digest words to `digest`.
  // This consumes the state: the object must be reset() before it could hash
  // another message, which is why str() calls it on a copy.
  void get_digest(digest_type& digest) {
    // append the bit '1' to the message
    process_byte_impl(0x80);

    // append k bits '0', where k is the minimum number >= 0
    // such that the resulting message length is congruent to 56 (mod 64).
    // If the current block has no room left for the 8 length bytes, zero-fill
    // it (which compresses it) and continue padding in a fresh block.
    if (block_byte_index_ > 56) {
      while (block_byte_index_ != 0) {
        process_byte_impl(0);
      }
    }
    while (block_byte_index_ < 56) {
      process_byte_impl(0);
    }

    // append length of message (before pre-processing)
    // as a 64-bit big-endian integer; the last byte completes the block and
    // triggers the final compression.
    append_be32(bit_count_high);
    append_be32(bit_count_low);

    // get final digest
    for (std::size_t i = 0; i < 5; ++i) {
      digest[i] = h_[i];
    }
  }

  unsigned int h_[5] = {}; // chaining values / running digest
  unsigned char block_[64] = {}; // bytes of the block being assembled
  std::size_t block_byte_index_ = 0; // number of valid bytes in block_
  std::size_t bit_count_low = 0; // low 32 bits of the message bit length
  std::size_t bit_count_high = 0; // high 32 bits of the message bit length
};
235
236////////////////////////////////////////////////////////////////////////////////
237// c10::hash implementation
238////////////////////////////////////////////////////////////////////////////////
239
namespace _hash_detail {

// Hashes `o` through c10::hash<T>, letting template argument deduction pick
// T so call sites stay short. Only declared here; the definition appears
// later in this file.
template <typename T>
size_t simple_get_hash(const T& o);

// Names the type V when T is not an enum; otherwise the alias is ill-formed,
// which removes the enclosing overload from consideration.
template <typename T, typename V>
using type_if_not_enum =
    typename std::enable_if<!std::is_enum<T>::value, V>::type;

// dispatch_hash picks a hashing strategy via SFINAE:
//   1. non-enum types that std::hash<T> can hash use std::hash<T>;
//   2. enum types are hashed through their underlying integer type;
//   3. types exposing a static `T::hash(const T&)` use that function.
// NOTE: C++14 added support for hashing enum types to the standard, and some
// compilers implement it even when C++14 flags aren't specified. The first
// overload is therefore disabled explicitly for enums so that it never
// competes with the second one.
template <typename T>
auto dispatch_hash(const T& value)
    -> decltype(std::hash<T>()(value), type_if_not_enum<T, size_t>()) {
  std::hash<T> std_hasher;
  return std_hasher(value);
}

// Enum overload: convert to the underlying type and hash that.
template <typename T>
auto dispatch_hash(const T& value) ->
    typename std::enable_if<std::is_enum<T>::value, size_t>::type {
  typedef typename std::underlying_type<T>::type Underlying;
  const Underlying raw = static_cast<Underlying>(value);
  return std::hash<Underlying>()(raw);
}

// Fallback for types that provide their own static hash function.
template <typename T>
auto dispatch_hash(const T& value) -> decltype(T::hash(value), size_t()) {
  return T::hash(value);
}

} // namespace _hash_detail
274
275// Hasher struct
276template <typename T>
277struct hash {
278 size_t operator()(const T& o) const {
279 return _hash_detail::dispatch_hash(o);
280 };
281};
282
283// Specialization for std::tuple
284template <typename... Types>
285struct hash<std::tuple<Types...>> {
286 template <size_t idx, typename... Ts>
287 struct tuple_hash {
288 size_t operator()(const std::tuple<Ts...>& t) const {
289 return hash_combine(
290 _hash_detail::simple_get_hash(std::get<idx>(t)),
291 tuple_hash<idx - 1, Ts...>()(t));
292 }
293 };
294
295 template <typename... Ts>
296 struct tuple_hash<0, Ts...> {
297 size_t operator()(const std::tuple<Ts...>& t) const {
298 return _hash_detail::simple_get_hash(std::get<0>(t));
299 }
300 };
301
302 size_t operator()(const std::tuple<Types...>& t) const {
303 return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
304 }
305};
306
307template <typename T1, typename T2>
308struct hash<std::pair<T1, T2>> {
309 size_t operator()(const std::pair<T1, T2>& pair) const {
310 std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
311 return _hash_detail::simple_get_hash(tuple);
312 }
313};
314
315template <typename T>
316struct hash<c10::ArrayRef<T>> {
317 size_t operator()(c10::ArrayRef<T> v) const {
318 size_t seed = 0;
319 for (const auto& elem : v) {
320 seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
321 }
322 return seed;
323 }
324};
325
326// Specialization for std::vector
327template <typename T>
328struct hash<std::vector<T>> {
329 size_t operator()(const std::vector<T>& v) const {
330 return hash<c10::ArrayRef<T>>()(v);
331 }
332};
333
334namespace _hash_detail {
335
336template <typename T>
337size_t simple_get_hash(const T& o) {
338 return c10::hash<T>()(o);
339}
340
341} // namespace _hash_detail
342
343// Use this function to actually hash multiple things in one line.
344// Dispatches to c10::hash, so it can hash containers.
345// Example:
346//
347// static size_t hash(const MyStruct& s) {
348// return get_hash(s.member1, s.member2, s.member3);
349// }
350template <typename... Types>
351size_t get_hash(const Types&... args) {
352 return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
353}
354
355// Specialization for c10::complex
356template <typename T>
357struct hash<c10::complex<T>> {
358 size_t operator()(const c10::complex<T>& c) const {
359 return get_hash(c.real(), c.imag());
360 }
361};
362
363} // namespace c10
364