1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
 * \file tvm/node/structural_hash.h
21 * \brief Structural hash class.
22 */
23#ifndef TVM_NODE_STRUCTURAL_HASH_H_
24#define TVM_NODE_STRUCTURAL_HASH_H_
25
26#include <tvm/node/functor.h>
27#include <tvm/runtime/data_type.h>
28#include <tvm/runtime/ndarray.h>
29
30#include <functional>
31#include <string>
32
33namespace tvm {
34
35/*!
36 * \brief Hash definition of base value classes.
37 */
38class BaseValueHash {
39 public:
40 size_t operator()(const double& key) const { return std::hash<double>()(key); }
41
42 size_t operator()(const int64_t& key) const { return std::hash<int64_t>()(key); }
43
44 size_t operator()(const uint64_t& key) const { return std::hash<uint64_t>()(key); }
45
46 size_t operator()(const int& key) const { return std::hash<int>()(key); }
47
48 size_t operator()(const bool& key) const { return std::hash<bool>()(key); }
49
50 size_t operator()(const std::string& key) const { return std::hash<std::string>()(key); }
51
52 size_t operator()(const runtime::DataType& key) const {
53 return std::hash<int32_t>()(static_cast<int32_t>(key.code()) |
54 (static_cast<int32_t>(key.bits()) << 8) |
55 (static_cast<int32_t>(key.lanes()) << 16));
56 }
57
58 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
59 bool operator()(const ENum& key) const {
60 return std::hash<size_t>()(static_cast<size_t>(key));
61 }
62};
63
64/*!
65 * \brief Content-aware structural hasing.
66 *
67 * The structural hash value is recursively defined in the DAG of IRNodes.
68 * There are two kinds of nodes:
69 *
70 * - Normal node: the hash value is defined by its content and type only.
71 * - Graph node: each graph node will be assigned a unique index ordered by the
72 * first occurence during the visit. The hash value of a graph node is
73 * combined from the hash values of its contents and the index.
74 */
75class StructuralHash : public BaseValueHash {
76 public:
77 // inheritate operator()
78 using BaseValueHash::operator();
79 /*!
80 * \brief Compute structural hashing value for an object.
81 * \param key The left operand.
82 * \return The hash value.
83 */
84 TVM_DLL size_t operator()(const ObjectRef& key) const;
85};
86
87/*!
88 * \brief A Reducer class to reduce the structural hash value.
89 *
90 * The reducer will call the SEqualHash function of each objects recursively.
91 *
92 * A SEqualHash function will make a sequence of calls to the reducer to
93 * indicate a sequence of child hash values that the reducer need to combine
94 * inorder to obtain the hash value of the hash value of the parent object.
95 *
96 * Importantly, the reducer may not directly use recursive calls
97 * to compute the hash values of child objects directly.
98 *
99 * Instead, it can store the necessary hash computing task into a stack
100 * and reduce the result later.
101 */
102class SHashReducer {
103 public:
104 /*! \brief Internal handler that defines custom behaviors. */
105 class Handler {
106 public:
107 /*!
108 * \brief Append hashed_value to the current sequence of hashes.
109 *
110 * \param hashed_value The hashed value
111 */
112 virtual void SHashReduceHashedValue(size_t hashed_value) = 0;
113 /*!
114 * \brief Append hash value of key to the current sequence of hashes.
115 *
116 * \param key The object to compute hash from.
117 * \param map_free_vars Whether to map free variables by their occurence number.
118 */
119 virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0;
120 /*!
121 * \brief Apppend a hash value of free variable to the current sequence of hashes.
122 *
123 * \param var The var of interest.
124 * \param map_free_vars Whether to map free variables by their occurence number.
125 *
126 * \note If map_free_vars is set to be true,
127 * internally the handler can maintain a counter to encode free variables
128 * by their order of occurence. This helps to resolve variable
129 * mapping of function parameters and let binding variables.
130 *
131 * If map_free_vars is set to be false, the address of the variable will be used.
132 */
133 virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0;
134 /*!
135 * \brief Lookup a hash value for key
136 *
137 * \param key The hash key.
138 * \param hashed_value the result hash value
139 *
140 * \return Whether there is already a pre-computed hash value.
141 */
142 virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0;
143 /*!
144 * \brief Mark current comparison as graph node in hashing.
145 * Graph node hash will depends on the graph structure.
146 */
147 virtual void MarkGraphNode() = 0;
148 };
149
150 /*! \brief default constructor */
151 SHashReducer() = default;
152 /*!
153 * \brief Constructor with a specific handler.
154 * \param handler The equal handler for objects.
155 * \param map_free_vars Whether to map free variables.
156 */
157 explicit SHashReducer(Handler* handler, bool map_free_vars)
158 : handler_(handler), map_free_vars_(map_free_vars) {}
159 /*!
160 * \brief Push hash of key to the current sequence of hash values.
161 * \param key The key to be hashed.
162 */
163 template <typename T,
164 typename = typename std::enable_if<!std::is_base_of<ObjectRef, T>::value>::type>
165 void operator()(const T& key) const {
166 // handle normal values.
167 handler_->SHashReduceHashedValue(BaseValueHash()(key));
168 }
169 /*!
170 * \brief Push hash of key to the current sequence of hash values.
171 * \param key The key to be hashed.
172 */
173 void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); }
174 /*!
175 * \brief Push hash of key to the current sequence of hash values.
176 * \param key The key to be hashed.
177 * \note This function indicate key could contain var defintions.
178 */
179 void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); }
180 /*!
181 * \brief Implementation for hash for a free var.
182 * \param var The variable.
183 * \return the result.
184 */
185 void FreeVarHashImpl(const runtime::Object* var) const {
186 handler_->SHashReduceFreeVar(var, map_free_vars_);
187 }
188
189 /*! \return Get the internal handler. */
190 Handler* operator->() const { return handler_; }
191
192 private:
193 /*! \brief Internal class pointer. */
194 Handler* handler_;
195 /*!
196 * \brief Whether or not to map free variables by their occurence
197 * If the flag is false, then free variables will be mapped
198 * by their in-memory address.
199 */
200 bool map_free_vars_;
201};
202
203/*! \brief The default handler for hash key computation
204 *
205 * Users can derive from this class and override the DispatchSHash method,
206 * to customize hashing.
207 */
208class SHashHandlerDefault : public SHashReducer::Handler {
209 public:
210 SHashHandlerDefault();
211 virtual ~SHashHandlerDefault();
212
213 void SHashReduceHashedValue(size_t hashed_value) override;
214 void SHashReduce(const ObjectRef& key, bool map_free_vars) override;
215 void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override;
216 bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) override;
217 void MarkGraphNode() override;
218
219 /*!
220 * \brief The entry point for hashing
221 * \param object The object to be hashed.
222 * \param map_free_vars Whether or not to remap variables if possible.
223 * \return The hash result.
224 */
225 virtual size_t Hash(const ObjectRef& object, bool map_free_vars);
226
227 protected:
228 /*!
229 * \brief The dispatcher for hashing of intermediate objects
230 * \param object An intermediate object to be hashed.
231 * \param map_free_vars Whether or not to remap variables if possible.
232 * \return The hash result.
233 */
234 virtual void DispatchSHash(const ObjectRef& object, bool map_free_vars);
235
236 private:
237 class Impl;
238 Impl* impl;
239};
240
241class SEqualReducer;
242struct NDArrayContainerTrait {
243 static constexpr const std::nullptr_t VisitAttrs = nullptr;
244 static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce);
245 static bool SEqualReduce(const runtime::NDArray::Container* lhs,
246 const runtime::NDArray::Container* rhs, SEqualReducer equal);
247};
248
249} // namespace tvm
250#endif // TVM_NODE_STRUCTURAL_HASH_H_
251