/**
 * Hash utils in this file are adapted from PyTorch/XLA
 * https://github.com/pytorch/xla/blob/e0e5f937a0ba8d904f9608137dc8c51ba439df2d/third_party/xla_client/util.h
 */
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>
#include <c10/util/int128.h>
#include <torch/csrc/Export.h>
#include <cstring>
#include <set>
#include <string>
#include <vector>

namespace torch {
namespace lazy {

using size_t = std::size_t;

class TORCH_API hash_t : public c10::uint128 {
 public:
  // Switched from `typedef hash_t = uint128` to provide explicit casters
  hash_t(int8_t val) : uint128(static_cast<uint32_t>(val)) {}
  hash_t(int16_t val) : uint128(static_cast<uint32_t>(val)) {}
  hash_t(int32_t val) : uint128(static_cast<uint32_t>(val)) {}
  hash_t(int64_t val) : uint128(static_cast<uint64_t>(val)) {}
  hash_t(uint32_t val) : uint128(val) {}
  hash_t(uint64_t val) : uint128(val) {}
  hash_t(uint128 val) : uint128(val) {}
  hash_t(uint64_t top, uint64_t bottom) : uint128(top, bottom) {}
  hash_t() : uint128() {}
};
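
// For illustration (a sketch, not part of the API): hash_t converts from the
// common integer types and from uint128, so hash values can be built directly
// from integers when convenient, e.g.
//
//   hash_t h1 = static_cast<int64_t>(42);         // widened to 128 bits
//   hash_t h2(/*top=*/0x1234, /*bottom=*/0x5678); // explicit high/low words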

// Std* functions use a 64-bit hash
size_t TORCH_API StdDataHash(const void* data, size_t size);

size_t TORCH_API StdHashCombine(uintmax_t a, uintmax_t b);

// Other functions are all 128-bit
hash_t TORCH_API HashBlock(const void* data, size_t n, const hash_t& seed);

// Hashes a raw block of memory of the given size in bytes
hash_t TORCH_API DataHash(const void* data, size_t size);

// Combines two 128-bit hashes into one
hash_t TORCH_API HashCombine(const hash_t& a, const hash_t& b);

// Reduces a 128-bit hash to a 64-bit size_t, e.g. for std:: containers
size_t TORCH_API HashReduce(const hash_t& a);

// Returns a string representation of a hash
std::string TORCH_API HashToString(const hash_t& a);

struct HashReducer {
  size_t operator()(const hash_t& value) const {
    return HashReduce(value);
  }
};
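
// Example usage (a sketch, not part of this header): HashReducer makes hash_t
// usable as the key type of unordered containers, e.g.
//
//   std::unordered_map<hash_t, CachedValue, HashReducer> cache;
//   cache.emplace(Hash(std::string("some_key")), CachedValue{});
//
// where `CachedValue` is a hypothetical payload type.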

static inline hash_t StringHash(const char* data) {
  return DataHash(data, std::strlen(data));
}

// Automatic templated implementation for 'arithmetic' types
template <
    typename T,
    typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
hash_t Hash(const T& value) {
  return DataHash(&value, sizeof(value));
}
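
// For illustration (a sketch): the template above covers plain arithmetic
// values, hashing their in-memory representation, e.g.
//
//   hash_t hi = Hash(static_cast<int64_t>(7)); // via the arithmetic template
//   hash_t hf = Hash(2.5f);                    // float is arithmetic too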

// Added because on macOS builds the std::vector<bool> specialization
// breaks falling through to the templated arithmetic types above
hash_t TORCH_API Hash(const std::vector<bool>& value);

// Specialized implementations for proprietary types
static inline hash_t Hash(const c10::ScalarType& value) {
  return DataHash(&value, sizeof(value));
}

static inline hash_t Hash(const c10::MemoryFormat& value) {
  return DataHash(&value, sizeof(value));
}

static inline hash_t Hash(const c10::DeviceType& value) {
  return DataHash(&value, sizeof(value));
}

static inline hash_t Hash(const c10::Device& value) {
  return HashCombine(Hash(value.type()), Hash(value.index()));
}

static inline hash_t Hash(const c10::Layout& value) {
  return DataHash(&value, sizeof(value));
}

static inline hash_t Hash(const c10::Scalar& value) {
  switch (value.type()) {
    case c10::ScalarType::ComplexDouble:
      return Hash(value.toComplexDouble());
    case c10::ScalarType::Double:
      return Hash(value.toDouble());
    case c10::ScalarType::Long:
      return Hash(value.toLong());
    case c10::ScalarType::Bool:
      return Hash(value.toBool());
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown scalar type.", value.type());
  }
}
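
// For illustration (a sketch): scalars are dispatched on their tag and then
// hashed through the overloads above, e.g.
//
//   hash_t h = Hash(c10::Scalar(2.5)); // takes the Double branch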

// Hashes the flattened contents of a tensor (the values only, not metadata
// such as sizes or strides); the tensor is made contiguous first
static inline hash_t TensorHash(const at::Tensor& tensor) {
  at::Tensor ctensor = tensor.contiguous();
  int64_t size = ctensor.numel() * ctensor.element_size();
  switch (ctensor.scalar_type()) {
    case at::ScalarType::Bool:
      return DataHash(ctensor.data_ptr<bool>(), size);
    case at::ScalarType::Byte:
      return DataHash(ctensor.data_ptr<uint8_t>(), size);
    case at::ScalarType::Char:
      return DataHash(ctensor.data_ptr<int8_t>(), size);
    case at::ScalarType::Short:
      return DataHash(ctensor.data_ptr<int16_t>(), size);
    case at::ScalarType::Int:
      return DataHash(ctensor.data_ptr<int32_t>(), size);
    case at::ScalarType::Long:
      return DataHash(ctensor.data_ptr<int64_t>(), size);
    case at::ScalarType::Float:
      return DataHash(ctensor.data_ptr<float>(), size);
    case at::ScalarType::Double:
      return DataHash(ctensor.data_ptr<double>(), size);
    case at::ScalarType::BFloat16:
      return DataHash(ctensor.data_ptr<at::BFloat16>(), size);
    case at::ScalarType::Half:
      return DataHash(ctensor.data_ptr<at::Half>(), size);
    case at::ScalarType::ComplexFloat:
      return DataHash(ctensor.data_ptr<c10::complex<float>>(), size);
    case at::ScalarType::ComplexDouble:
      return DataHash(ctensor.data_ptr<c10::complex<double>>(), size);
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unsupported scalar type:", ctensor.scalar_type());
  }
}
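
// For illustration (a sketch): two tensors with equal contents and dtype hash
// to the same value regardless of their original memory layout, since the
// data is made contiguous before hashing, e.g.
//
//   at::Tensor a = at::arange(6).reshape({2, 3});
//   at::Tensor b = a.t().contiguous().t(); // same values, different strides
//   TORCH_CHECK(TensorHash(a) == TensorHash(b));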

static inline hash_t Hash(const std::string& value) {
  return DataHash(value.data(), value.size());
}

static inline hash_t Hash(const c10::string_view& value) {
  return DataHash(value.data(), value.size());
}

// Taken from glibc's implementation of hashing optionals,
// we want to include a contribution to the hash to distinguish
// cases where one or another optional was null, but we hope it doesn't
// collide with an actual scalar value.
//
// Use an arbitrary randomly-selected 64-bit integer rather than a
// small constant that we then hash at runtime so we don't have to
// repeatedly hash a constant at runtime.
static const int64_t kNullOpt = 0x8655d738f3678dda;

// Hashing for c10::optional types contributes to the hash even
// for optionals with a null value; this is important to distinguish
// between <nullopt, non-nullopt> and <non-nullopt, nullopt> cases
template <typename T>
hash_t Hash(const c10::optional<T>& value) {
  if (value.has_value()) {
    return Hash(value.value());
  } else {
    return kNullOpt;
  }
}
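
// For illustration (a sketch): the kNullOpt sentinel keeps a nullopt from
// contributing "nothing", so which field is absent changes the combined hash,
// e.g.
//
//   c10::optional<int64_t> some = 5, none = c10::nullopt;
//   // HashCombine(Hash(some), Hash(none)) differs from
//   // HashCombine(Hash(none), Hash(some))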

// Hashing of containers
// Forward declare to allow hashes of vectors of vectors to work.
template <typename T>
hash_t ContainerHash(const T& values);

template <typename T>
hash_t Hash(const std::vector<T>& values) {
  return ContainerHash(values);
}

// Special case for c10::optional<std::vector<T>>: hash the contained
// vector directly
template <typename T>
hash_t Hash(const c10::optional<std::vector<T>>& value) {
  if (value.has_value()) {
    return ContainerHash(value.value());
  } else {
    return kNullOpt;
  }
}

template <typename T>
hash_t Hash(const std::set<T>& values) {
  return ContainerHash(values);
}

template <typename T, typename S>
hash_t Hash(const std::pair<T, S>& values) {
  return HashCombine(Hash(values.first), Hash(values.second));
}
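
// For illustration (a sketch): pair hashing combines both element hashes,
// e.g.
//
//   hash_t h = Hash(std::make_pair(int64_t(3), std::string("add")));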

static inline hash_t Hash(const hash_t& value) {
  return value;
}

template <typename T>
hash_t Hash(c10::ArrayRef<T> values) {
  return ContainerHash(values);
}

// Hashes a container by folding each element's hash into a fixed seed
template <typename T>
hash_t ContainerHash(const T& values) {
  hash_t h(static_cast<uint64_t>(0x85ebca77c2b2ae63));
  for (const auto& value : values) {
    h = HashCombine(h, Hash(value));
  }
  return h;
}
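
// For illustration (a sketch): nested containers compose through the
// overloads above, e.g.
//
//   std::vector<std::vector<int64_t>> shapes = {{2, 3}, {4}};
//   hash_t h = Hash(shapes); // hashes each inner vector, then the outer one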

// Varargs hashing
template <typename T = void>
hash_t MHash() {
  return hash_t(static_cast<uint64_t>(0x165667b19e3779f9));
}

template <typename T, typename... Targs>
hash_t MHash(T value, Targs... Fargs) {
  return HashCombine(Hash(value), MHash(Fargs...));
}
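
// For illustration (a sketch): MHash hashes a heterogeneous argument list in
// one call by recursively combining each argument's hash, e.g.
//
//   hash_t h = MHash(c10::ScalarType::Float, int64_t(4), std::string("relu"));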

} // namespace lazy
} // namespace torch