1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/node/structural_equal.h |
21 | * \brief Structural equality comparison. |
22 | */ |
23 | #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ |
24 | #define TVM_NODE_STRUCTURAL_EQUAL_H_ |
25 | |
#include <tvm/node/functor.h>
#include <tvm/node/object_path.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

#include <cmath>
#include <string>
32 | |
33 | namespace tvm { |
34 | |
35 | /*! |
36 | * \brief Equality definition of base value class. |
37 | */ |
38 | class BaseValueEqual { |
39 | public: |
40 | bool operator()(const double& lhs, const double& rhs) const { |
41 | // fuzzy float pt comparison |
42 | constexpr double atol = 1e-9; |
43 | if (lhs == rhs) return true; |
44 | double diff = lhs - rhs; |
45 | return diff > -atol && diff < atol; |
46 | } |
47 | |
48 | bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } |
49 | bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } |
50 | bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } |
51 | bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } |
52 | bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } |
53 | bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } |
54 | template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type> |
55 | bool operator()(const ENum& lhs, const ENum& rhs) const { |
56 | return lhs == rhs; |
57 | } |
58 | }; |
59 | |
60 | /*! |
61 | * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality. |
62 | */ |
63 | class ObjectPathPairNode : public Object { |
64 | public: |
65 | ObjectPath lhs_path; |
66 | ObjectPath rhs_path; |
67 | |
68 | ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path); |
69 | |
70 | static constexpr const char* _type_key = "ObjectPathPair" ; |
71 | TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object); |
72 | }; |
73 | |
74 | class ObjectPathPair : public ObjectRef { |
75 | public: |
76 | ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path); |
77 | |
78 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode); |
79 | }; |
80 | |
81 | /*! |
82 | * \brief Content-aware structural equality comparator for objects. |
83 | * |
84 | * The structural equality is recursively defined in the DAG of IR nodes via SEqual. |
85 | * There are two kinds of nodes: |
86 | * |
87 | * - Graph node: a graph node in lhs can only be mapped as equal to |
88 | * one and only one graph node in rhs. |
89 | * - Normal node: equality is recursively defined without the restriction |
90 | * of graph nodes. |
91 | * |
92 | * Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes. |
93 | * For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal |
94 | * to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay. |
95 | * |
96 | * A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var |
97 | * with the same type if one of the following condition holds: |
98 | * |
99 | * - They appear in a same definition point(e.g. function argument). |
100 | * - They points to the same VarNode via the same_as relation. |
101 | * - They appear in a same usage point, and map_free_vars is set to be True. |
102 | */ |
103 | class StructuralEqual : public BaseValueEqual { |
104 | public: |
105 | // inheritate operator() |
106 | using BaseValueEqual::operator(); |
107 | /*! |
108 | * \brief Compare objects via strutural equal. |
109 | * \param lhs The left operand. |
110 | * \param rhs The right operand. |
111 | * \return The comparison result. |
112 | */ |
113 | TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; |
114 | }; |
115 | |
116 | /*! |
117 | * \brief A Reducer class to reduce the structural equality result of two objects. |
118 | * |
119 | * The reducer will call the SEqualReduce function of each objects recursively. |
120 | * Importantly, the reducer may not directly use recursive calls to resolve the |
121 | * equality checking. Instead, it can store the necessary equality conditions |
122 | * and check later via an internally managed stack. |
123 | */ |
124 | class SEqualReducer { |
125 | private: |
126 | struct PathTracingData; |
127 | |
128 | public: |
129 | /*! \brief Internal handler that defines custom behaviors.. */ |
130 | class Handler { |
131 | public: |
132 | /*! |
133 | * \brief Reduce condition to equality of lhs and rhs. |
134 | * |
135 | * \param lhs The left operand. |
136 | * \param rhs The right operand. |
137 | * \param map_free_vars Whether do we allow remap variables if possible. |
138 | * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability. |
139 | * |
140 | * \return false if there is an immediate failure, true otherwise. |
141 | * \note This function may save the equality condition of (lhs == rhs) in an internal |
142 | * stack and try to resolve later. |
143 | */ |
144 | virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
145 | const Optional<ObjectPathPair>& current_paths) = 0; |
146 | |
147 | /*! |
148 | * \brief Mark the comparison as failed, but don't fail immediately. |
149 | * |
150 | * This is useful for producing better error messages when comparing containers. |
151 | * For example, if two array sizes mismatch, it's better to mark the comparison as failed |
152 | * but compare array elements anyway, so that we could find the true first mismatch. |
153 | */ |
154 | virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0; |
155 | |
156 | /*! |
157 | * \brief Check if fail defferal is enabled. |
158 | * |
159 | * \return false if the fail deferral is not enabled, true otherwise. |
160 | */ |
161 | virtual bool IsFailDeferralEnabled() = 0; |
162 | |
163 | /*! |
164 | * \brief Lookup the graph node equal map for vars that are already mapped. |
165 | * |
166 | * This is an auxiliary method to check the Map<Var, Value> equality. |
167 | * \param lhs an lhs value. |
168 | * |
169 | * \return The corresponding rhs value if any, nullptr if not available. |
170 | */ |
171 | virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0; |
172 | /*! |
173 | * \brief Mark current comparison as graph node equal comparison. |
174 | */ |
175 | virtual void MarkGraphNode() = 0; |
176 | |
177 | protected: |
178 | using PathTracingData = SEqualReducer::PathTracingData; |
179 | }; |
180 | |
181 | /*! \brief default constructor */ |
182 | SEqualReducer() = default; |
183 | /*! |
184 | * \brief Constructor with a specific handler. |
185 | * \param handler The equal handler for objects. |
186 | * \param tracing_data Optional pointer to the path tracing data. |
187 | * \param map_free_vars Whether or not to map free variables. |
188 | */ |
189 | explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars) |
190 | : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {} |
191 | |
192 | /*! |
193 | * \brief Reduce condition to comparison of two attribute values. |
194 | * \param lhs The left operand. |
195 | * \param rhs The right operand. |
196 | * \return the immediate check result. |
197 | */ |
198 | bool operator()(const double& lhs, const double& rhs) const; |
199 | bool operator()(const int64_t& lhs, const int64_t& rhs) const; |
200 | bool operator()(const uint64_t& lhs, const uint64_t& rhs) const; |
201 | bool operator()(const int& lhs, const int& rhs) const; |
202 | bool operator()(const bool& lhs, const bool& rhs) const; |
203 | bool operator()(const std::string& lhs, const std::string& rhs) const; |
204 | bool operator()(const DataType& lhs, const DataType& rhs) const; |
205 | |
206 | template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type> |
207 | bool operator()(const ENum& lhs, const ENum& rhs) const { |
208 | using Underlying = typename std::underlying_type<ENum>::type; |
209 | static_assert(std::is_same<Underlying, int>::value, |
210 | "Enum must have `int` as the underlying type" ); |
211 | return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs); |
212 | } |
213 | |
214 | /*! |
215 | * \brief Reduce condition to comparison of two objects. |
216 | * \param lhs The left operand. |
217 | * \param rhs The right operand. |
218 | * \return the immediate check result. |
219 | */ |
220 | bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; |
221 | |
222 | /*! |
223 | * \brief Reduce condition to comparison of two objects. |
224 | * |
225 | * Like `operator()`, but with an additional `paths` parameter that specifies explicit object |
226 | * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container |
227 | * objects like Array and Map, or other custom objects that store nested objects that are not |
228 | * simply attributes. |
229 | * |
230 | * Can only be called when `IsPathTracingEnabled()` is `true`. |
231 | * |
232 | * \param lhs The left operand. |
233 | * \param rhs The right operand. |
234 | * \param paths Object paths for `lhs` and `rhs`. |
235 | * \return the immediate check result. |
236 | */ |
237 | bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const { |
238 | ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function" ; |
239 | return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths); |
240 | } |
241 | |
242 | /*! |
243 | * \brief Reduce condition to comparison of two definitions, |
244 | * where free vars can be mapped. |
245 | * |
246 | * Call this function to compare definition points such as function params |
247 | * and var in a let-binding. |
248 | * |
249 | * \param lhs The left operand. |
250 | * \param rhs The right operand. |
251 | * \return the immediate check result. |
252 | */ |
253 | bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); |
254 | |
255 | /*! |
256 | * \brief Reduce condition to comparison of two arrays. |
257 | * \param lhs The left operand. |
258 | * \param rhs The right operand. |
259 | * \return the immediate check result. |
260 | */ |
261 | template <typename T> |
262 | bool operator()(const Array<T>& lhs, const Array<T>& rhs) const { |
263 | if (tracing_data_ == nullptr) { |
264 | // quick specialization for Array to reduce amount of recursion |
265 | // depth as array comparison is pretty common. |
266 | if (lhs.size() != rhs.size()) return false; |
267 | for (size_t i = 0; i < lhs.size(); ++i) { |
268 | if (!(operator()(lhs[i], rhs[i]))) return false; |
269 | } |
270 | return true; |
271 | } |
272 | |
273 | // If tracing is enabled, fall back to the regular path |
274 | const ObjectRef& lhs_obj = lhs; |
275 | const ObjectRef& rhs_obj = rhs; |
276 | return (*this)(lhs_obj, rhs_obj); |
277 | } |
278 | /*! |
279 | * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var). |
280 | * \param lhs The left operand. |
281 | * \param rhs The right operand. |
282 | * \return the result. |
283 | */ |
284 | bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const { |
285 | // var need to be remapped, so it belongs to graph node. |
286 | handler_->MarkGraphNode(); |
287 | // We only map free vars if they corresponds to the same address |
288 | // or map free_var option is set to be true. |
289 | return lhs == rhs || map_free_vars_; |
290 | } |
291 | |
292 | /*! \return Get the internal handler. */ |
293 | Handler* operator->() const { return handler_; } |
294 | |
295 | /*! \brief Check if this reducer is tracing paths to the first mismatch. */ |
296 | bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; } |
297 | |
298 | /*! |
299 | * \brief Get the paths of the currently compared objects. |
300 | * |
301 | * Can only be called when `IsPathTracingEnabled()` is true. |
302 | */ |
303 | const ObjectPathPair& GetCurrentObjectPaths() const; |
304 | |
305 | /*! |
306 | * \brief Specify the object paths of a detected mismatch. |
307 | * |
308 | * Can only be called when `IsPathTracingEnabled()` is true. |
309 | */ |
310 | void RecordMismatchPaths(const ObjectPathPair& paths) const; |
311 | |
312 | private: |
313 | bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const; |
314 | |
315 | bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
316 | const ObjectPathPair* paths) const; |
317 | |
318 | static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address, |
319 | const void* rhs_address, |
320 | const PathTracingData* tracing_data); |
321 | |
322 | template <typename T> |
323 | static bool CompareAttributeValues(const T& lhs, const T& rhs, |
324 | const PathTracingData* tracing_data); |
325 | |
326 | /*! \brief Internal class pointer. */ |
327 | Handler* handler_ = nullptr; |
328 | /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */ |
329 | const PathTracingData* tracing_data_ = nullptr; |
330 | /*! \brief Whether or not to map free vars. */ |
331 | bool map_free_vars_ = false; |
332 | }; |
333 | |
334 | /*! \brief The default handler for equality testing. |
335 | * |
336 | * Users can derive from this class and override the DispatchSEqualReduce method, |
337 | * to customize equality testing. |
338 | */ |
339 | class SEqualHandlerDefault : public SEqualReducer::Handler { |
340 | public: |
341 | SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch, |
342 | bool defer_fails); |
343 | virtual ~SEqualHandlerDefault(); |
344 | |
345 | bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
346 | const Optional<ObjectPathPair>& current_paths) override; |
347 | void DeferFail(const ObjectPathPair& mismatch_paths) override; |
348 | bool IsFailDeferralEnabled() override; |
349 | ObjectRef MapLhsToRhs(const ObjectRef& lhs) override; |
350 | void MarkGraphNode() override; |
351 | |
352 | /*! |
353 | * \brief The entry point for equality testing |
354 | * \param lhs The left operand. |
355 | * \param rhs The right operand. |
356 | * \param map_free_vars Whether or not to remap variables if possible. |
357 | * \return The equality result. |
358 | */ |
359 | virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars); |
360 | |
361 | protected: |
362 | /*! |
363 | * \brief The dispatcher for equality testing of intermediate objects |
364 | * \param lhs The left operand. |
365 | * \param rhs The right operand. |
366 | * \param map_free_vars Whether or not to remap variables if possible. |
367 | * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability. |
368 | * \return The equality result. |
369 | */ |
370 | virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
371 | const Optional<ObjectPathPair>& current_paths); |
372 | |
373 | private: |
374 | class Impl; |
375 | Impl* impl; |
376 | }; |
377 | |
378 | } // namespace tvm |
379 | #endif // TVM_NODE_STRUCTURAL_EQUAL_H_ |
380 | |