1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
/*!
 * \file tvm/node/structural_hash.h
 * \brief Structural hash class.
 */
23 | #ifndef TVM_NODE_STRUCTURAL_HASH_H_ |
24 | #define TVM_NODE_STRUCTURAL_HASH_H_ |
25 | |
26 | #include <tvm/node/functor.h> |
27 | #include <tvm/runtime/data_type.h> |
28 | #include <tvm/runtime/ndarray.h> |
29 | |
30 | #include <functional> |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | |
35 | /*! |
36 | * \brief Hash definition of base value classes. |
37 | */ |
38 | class BaseValueHash { |
39 | public: |
40 | size_t operator()(const double& key) const { return std::hash<double>()(key); } |
41 | |
42 | size_t operator()(const int64_t& key) const { return std::hash<int64_t>()(key); } |
43 | |
44 | size_t operator()(const uint64_t& key) const { return std::hash<uint64_t>()(key); } |
45 | |
46 | size_t operator()(const int& key) const { return std::hash<int>()(key); } |
47 | |
48 | size_t operator()(const bool& key) const { return std::hash<bool>()(key); } |
49 | |
50 | size_t operator()(const std::string& key) const { return std::hash<std::string>()(key); } |
51 | |
52 | size_t operator()(const runtime::DataType& key) const { |
53 | return std::hash<int32_t>()(static_cast<int32_t>(key.code()) | |
54 | (static_cast<int32_t>(key.bits()) << 8) | |
55 | (static_cast<int32_t>(key.lanes()) << 16)); |
56 | } |
57 | |
58 | template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type> |
59 | bool operator()(const ENum& key) const { |
60 | return std::hash<size_t>()(static_cast<size_t>(key)); |
61 | } |
62 | }; |
63 | |
64 | /*! |
65 | * \brief Content-aware structural hasing. |
66 | * |
67 | * The structural hash value is recursively defined in the DAG of IRNodes. |
68 | * There are two kinds of nodes: |
69 | * |
70 | * - Normal node: the hash value is defined by its content and type only. |
71 | * - Graph node: each graph node will be assigned a unique index ordered by the |
72 | * first occurence during the visit. The hash value of a graph node is |
73 | * combined from the hash values of its contents and the index. |
74 | */ |
75 | class StructuralHash : public BaseValueHash { |
76 | public: |
77 | // inheritate operator() |
78 | using BaseValueHash::operator(); |
79 | /*! |
80 | * \brief Compute structural hashing value for an object. |
81 | * \param key The left operand. |
82 | * \return The hash value. |
83 | */ |
84 | TVM_DLL size_t operator()(const ObjectRef& key) const; |
85 | }; |
86 | |
87 | /*! |
88 | * \brief A Reducer class to reduce the structural hash value. |
89 | * |
90 | * The reducer will call the SEqualHash function of each objects recursively. |
91 | * |
92 | * A SEqualHash function will make a sequence of calls to the reducer to |
93 | * indicate a sequence of child hash values that the reducer need to combine |
94 | * inorder to obtain the hash value of the hash value of the parent object. |
95 | * |
96 | * Importantly, the reducer may not directly use recursive calls |
97 | * to compute the hash values of child objects directly. |
98 | * |
99 | * Instead, it can store the necessary hash computing task into a stack |
100 | * and reduce the result later. |
101 | */ |
102 | class SHashReducer { |
103 | public: |
104 | /*! \brief Internal handler that defines custom behaviors. */ |
105 | class Handler { |
106 | public: |
107 | /*! |
108 | * \brief Append hashed_value to the current sequence of hashes. |
109 | * |
110 | * \param hashed_value The hashed value |
111 | */ |
112 | virtual void SHashReduceHashedValue(size_t hashed_value) = 0; |
113 | /*! |
114 | * \brief Append hash value of key to the current sequence of hashes. |
115 | * |
116 | * \param key The object to compute hash from. |
117 | * \param map_free_vars Whether to map free variables by their occurence number. |
118 | */ |
119 | virtual void SHashReduce(const ObjectRef& key, bool map_free_vars) = 0; |
120 | /*! |
121 | * \brief Apppend a hash value of free variable to the current sequence of hashes. |
122 | * |
123 | * \param var The var of interest. |
124 | * \param map_free_vars Whether to map free variables by their occurence number. |
125 | * |
126 | * \note If map_free_vars is set to be true, |
127 | * internally the handler can maintain a counter to encode free variables |
128 | * by their order of occurence. This helps to resolve variable |
129 | * mapping of function parameters and let binding variables. |
130 | * |
131 | * If map_free_vars is set to be false, the address of the variable will be used. |
132 | */ |
133 | virtual void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) = 0; |
134 | /*! |
135 | * \brief Lookup a hash value for key |
136 | * |
137 | * \param key The hash key. |
138 | * \param hashed_value the result hash value |
139 | * |
140 | * \return Whether there is already a pre-computed hash value. |
141 | */ |
142 | virtual bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) = 0; |
143 | /*! |
144 | * \brief Mark current comparison as graph node in hashing. |
145 | * Graph node hash will depends on the graph structure. |
146 | */ |
147 | virtual void MarkGraphNode() = 0; |
148 | }; |
149 | |
150 | /*! \brief default constructor */ |
151 | SHashReducer() = default; |
152 | /*! |
153 | * \brief Constructor with a specific handler. |
154 | * \param handler The equal handler for objects. |
155 | * \param map_free_vars Whether to map free variables. |
156 | */ |
157 | explicit SHashReducer(Handler* handler, bool map_free_vars) |
158 | : handler_(handler), map_free_vars_(map_free_vars) {} |
159 | /*! |
160 | * \brief Push hash of key to the current sequence of hash values. |
161 | * \param key The key to be hashed. |
162 | */ |
163 | template <typename T, |
164 | typename = typename std::enable_if<!std::is_base_of<ObjectRef, T>::value>::type> |
165 | void operator()(const T& key) const { |
166 | // handle normal values. |
167 | handler_->SHashReduceHashedValue(BaseValueHash()(key)); |
168 | } |
169 | /*! |
170 | * \brief Push hash of key to the current sequence of hash values. |
171 | * \param key The key to be hashed. |
172 | */ |
173 | void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); } |
174 | /*! |
175 | * \brief Push hash of key to the current sequence of hash values. |
176 | * \param key The key to be hashed. |
177 | * \note This function indicate key could contain var defintions. |
178 | */ |
179 | void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } |
180 | /*! |
181 | * \brief Implementation for hash for a free var. |
182 | * \param var The variable. |
183 | * \return the result. |
184 | */ |
185 | void FreeVarHashImpl(const runtime::Object* var) const { |
186 | handler_->SHashReduceFreeVar(var, map_free_vars_); |
187 | } |
188 | |
189 | /*! \return Get the internal handler. */ |
190 | Handler* operator->() const { return handler_; } |
191 | |
192 | private: |
193 | /*! \brief Internal class pointer. */ |
194 | Handler* handler_; |
195 | /*! |
196 | * \brief Whether or not to map free variables by their occurence |
197 | * If the flag is false, then free variables will be mapped |
198 | * by their in-memory address. |
199 | */ |
200 | bool map_free_vars_; |
201 | }; |
202 | |
203 | /*! \brief The default handler for hash key computation |
204 | * |
205 | * Users can derive from this class and override the DispatchSHash method, |
206 | * to customize hashing. |
207 | */ |
208 | class SHashHandlerDefault : public SHashReducer::Handler { |
209 | public: |
210 | SHashHandlerDefault(); |
211 | virtual ~SHashHandlerDefault(); |
212 | |
213 | void SHashReduceHashedValue(size_t hashed_value) override; |
214 | void SHashReduce(const ObjectRef& key, bool map_free_vars) override; |
215 | void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) override; |
216 | bool LookupHashedValue(const ObjectRef& key, size_t* hashed_value) override; |
217 | void MarkGraphNode() override; |
218 | |
219 | /*! |
220 | * \brief The entry point for hashing |
221 | * \param object The object to be hashed. |
222 | * \param map_free_vars Whether or not to remap variables if possible. |
223 | * \return The hash result. |
224 | */ |
225 | virtual size_t Hash(const ObjectRef& object, bool map_free_vars); |
226 | |
227 | protected: |
228 | /*! |
229 | * \brief The dispatcher for hashing of intermediate objects |
230 | * \param object An intermediate object to be hashed. |
231 | * \param map_free_vars Whether or not to remap variables if possible. |
232 | * \return The hash result. |
233 | */ |
234 | virtual void DispatchSHash(const ObjectRef& object, bool map_free_vars); |
235 | |
236 | private: |
237 | class Impl; |
238 | Impl* impl; |
239 | }; |
240 | |
241 | class SEqualReducer; |
242 | struct NDArrayContainerTrait { |
243 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
244 | static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce); |
245 | static bool SEqualReduce(const runtime::NDArray::Container* lhs, |
246 | const runtime::NDArray::Container* rhs, SEqualReducer equal); |
247 | }; |
248 | |
249 | } // namespace tvm |
250 | #endif // TVM_NODE_STRUCTURAL_HASH_H_ |
251 | |