1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/node/structural_equal.h
21 * \brief Structural equality comparison.
22 */
23#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
24#define TVM_NODE_STRUCTURAL_EQUAL_H_
25
#include <tvm/node/functor.h>
#include <tvm/node/object_path.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

#include <cmath>
#include <string>
#include <type_traits>
32
33namespace tvm {
34
35/*!
36 * \brief Equality definition of base value class.
37 */
38class BaseValueEqual {
39 public:
40 bool operator()(const double& lhs, const double& rhs) const {
41 // fuzzy float pt comparison
42 constexpr double atol = 1e-9;
43 if (lhs == rhs) return true;
44 double diff = lhs - rhs;
45 return diff > -atol && diff < atol;
46 }
47
48 bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; }
49 bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; }
50 bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; }
51 bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; }
52 bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; }
53 bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; }
54 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
55 bool operator()(const ENum& lhs, const ENum& rhs) const {
56 return lhs == rhs;
57 }
58};
59
/*!
 * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
 *
 * Used both for the paths of the objects currently being compared and for the
 * paths at which a mismatch was detected.
 */
class ObjectPathPairNode : public Object {
 public:
  /*! \brief Path to the object on the left-hand side of the comparison. */
  ObjectPath lhs_path;
  /*! \brief Path to the object on the right-hand side of the comparison. */
  ObjectPath rhs_path;

  /*!
   * \brief Construct the node from the two paths.
   * \param lhs_path Path to the left-hand side object.
   * \param rhs_path Path to the right-hand side object.
   */
  ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);

  /*! \brief Type key identifying this node type. */
  static constexpr const char* _type_key = "ObjectPathPair";
  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
};
73
/*!
 * \brief Non-nullable managed reference to an ObjectPathPairNode.
 */
class ObjectPathPair : public ObjectRef {
 public:
  /*!
   * \brief Construct a new pair of object paths.
   * \param lhs_path Path to the left-hand side object.
   * \param rhs_path Path to the right-hand side object.
   */
  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
};
80
81/*!
82 * \brief Content-aware structural equality comparator for objects.
83 *
84 * The structural equality is recursively defined in the DAG of IR nodes via SEqual.
85 * There are two kinds of nodes:
86 *
87 * - Graph node: a graph node in lhs can only be mapped as equal to
88 * one and only one graph node in rhs.
89 * - Normal node: equality is recursively defined without the restriction
90 * of graph nodes.
91 *
92 * Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
93 * For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
94 * to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
95 *
96 * A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another var
97 * with the same type if one of the following condition holds:
98 *
99 * - They appear in a same definition point(e.g. function argument).
100 * - They points to the same VarNode via the same_as relation.
101 * - They appear in a same usage point, and map_free_vars is set to be True.
102 */
103class StructuralEqual : public BaseValueEqual {
104 public:
105 // inheritate operator()
106 using BaseValueEqual::operator();
107 /*!
108 * \brief Compare objects via strutural equal.
109 * \param lhs The left operand.
110 * \param rhs The right operand.
111 * \return The comparison result.
112 */
113 TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
114};
115
116/*!
117 * \brief A Reducer class to reduce the structural equality result of two objects.
118 *
119 * The reducer will call the SEqualReduce function of each objects recursively.
120 * Importantly, the reducer may not directly use recursive calls to resolve the
121 * equality checking. Instead, it can store the necessary equality conditions
122 * and check later via an internally managed stack.
123 */
124class SEqualReducer {
125 private:
126 struct PathTracingData;
127
128 public:
129 /*! \brief Internal handler that defines custom behaviors.. */
130 class Handler {
131 public:
132 /*!
133 * \brief Reduce condition to equality of lhs and rhs.
134 *
135 * \param lhs The left operand.
136 * \param rhs The right operand.
137 * \param map_free_vars Whether do we allow remap variables if possible.
138 * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
139 *
140 * \return false if there is an immediate failure, true otherwise.
141 * \note This function may save the equality condition of (lhs == rhs) in an internal
142 * stack and try to resolve later.
143 */
144 virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
145 const Optional<ObjectPathPair>& current_paths) = 0;
146
147 /*!
148 * \brief Mark the comparison as failed, but don't fail immediately.
149 *
150 * This is useful for producing better error messages when comparing containers.
151 * For example, if two array sizes mismatch, it's better to mark the comparison as failed
152 * but compare array elements anyway, so that we could find the true first mismatch.
153 */
154 virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
155
156 /*!
157 * \brief Check if fail defferal is enabled.
158 *
159 * \return false if the fail deferral is not enabled, true otherwise.
160 */
161 virtual bool IsFailDeferralEnabled() = 0;
162
163 /*!
164 * \brief Lookup the graph node equal map for vars that are already mapped.
165 *
166 * This is an auxiliary method to check the Map<Var, Value> equality.
167 * \param lhs an lhs value.
168 *
169 * \return The corresponding rhs value if any, nullptr if not available.
170 */
171 virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
172 /*!
173 * \brief Mark current comparison as graph node equal comparison.
174 */
175 virtual void MarkGraphNode() = 0;
176
177 protected:
178 using PathTracingData = SEqualReducer::PathTracingData;
179 };
180
181 /*! \brief default constructor */
182 SEqualReducer() = default;
183 /*!
184 * \brief Constructor with a specific handler.
185 * \param handler The equal handler for objects.
186 * \param tracing_data Optional pointer to the path tracing data.
187 * \param map_free_vars Whether or not to map free variables.
188 */
189 explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
190 : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
191
192 /*!
193 * \brief Reduce condition to comparison of two attribute values.
194 * \param lhs The left operand.
195 * \param rhs The right operand.
196 * \return the immediate check result.
197 */
198 bool operator()(const double& lhs, const double& rhs) const;
199 bool operator()(const int64_t& lhs, const int64_t& rhs) const;
200 bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
201 bool operator()(const int& lhs, const int& rhs) const;
202 bool operator()(const bool& lhs, const bool& rhs) const;
203 bool operator()(const std::string& lhs, const std::string& rhs) const;
204 bool operator()(const DataType& lhs, const DataType& rhs) const;
205
206 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
207 bool operator()(const ENum& lhs, const ENum& rhs) const {
208 using Underlying = typename std::underlying_type<ENum>::type;
209 static_assert(std::is_same<Underlying, int>::value,
210 "Enum must have `int` as the underlying type");
211 return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
212 }
213
214 /*!
215 * \brief Reduce condition to comparison of two objects.
216 * \param lhs The left operand.
217 * \param rhs The right operand.
218 * \return the immediate check result.
219 */
220 bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
221
222 /*!
223 * \brief Reduce condition to comparison of two objects.
224 *
225 * Like `operator()`, but with an additional `paths` parameter that specifies explicit object
226 * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
227 * objects like Array and Map, or other custom objects that store nested objects that are not
228 * simply attributes.
229 *
230 * Can only be called when `IsPathTracingEnabled()` is `true`.
231 *
232 * \param lhs The left operand.
233 * \param rhs The right operand.
234 * \param paths Object paths for `lhs` and `rhs`.
235 * \return the immediate check result.
236 */
237 bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
238 ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
239 return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
240 }
241
242 /*!
243 * \brief Reduce condition to comparison of two definitions,
244 * where free vars can be mapped.
245 *
246 * Call this function to compare definition points such as function params
247 * and var in a let-binding.
248 *
249 * \param lhs The left operand.
250 * \param rhs The right operand.
251 * \return the immediate check result.
252 */
253 bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
254
255 /*!
256 * \brief Reduce condition to comparison of two arrays.
257 * \param lhs The left operand.
258 * \param rhs The right operand.
259 * \return the immediate check result.
260 */
261 template <typename T>
262 bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
263 if (tracing_data_ == nullptr) {
264 // quick specialization for Array to reduce amount of recursion
265 // depth as array comparison is pretty common.
266 if (lhs.size() != rhs.size()) return false;
267 for (size_t i = 0; i < lhs.size(); ++i) {
268 if (!(operator()(lhs[i], rhs[i]))) return false;
269 }
270 return true;
271 }
272
273 // If tracing is enabled, fall back to the regular path
274 const ObjectRef& lhs_obj = lhs;
275 const ObjectRef& rhs_obj = rhs;
276 return (*this)(lhs_obj, rhs_obj);
277 }
278 /*!
279 * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
280 * \param lhs The left operand.
281 * \param rhs The right operand.
282 * \return the result.
283 */
284 bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
285 // var need to be remapped, so it belongs to graph node.
286 handler_->MarkGraphNode();
287 // We only map free vars if they corresponds to the same address
288 // or map free_var option is set to be true.
289 return lhs == rhs || map_free_vars_;
290 }
291
292 /*! \return Get the internal handler. */
293 Handler* operator->() const { return handler_; }
294
295 /*! \brief Check if this reducer is tracing paths to the first mismatch. */
296 bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
297
298 /*!
299 * \brief Get the paths of the currently compared objects.
300 *
301 * Can only be called when `IsPathTracingEnabled()` is true.
302 */
303 const ObjectPathPair& GetCurrentObjectPaths() const;
304
305 /*!
306 * \brief Specify the object paths of a detected mismatch.
307 *
308 * Can only be called when `IsPathTracingEnabled()` is true.
309 */
310 void RecordMismatchPaths(const ObjectPathPair& paths) const;
311
312 private:
313 bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
314
315 bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
316 const ObjectPathPair* paths) const;
317
318 static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
319 const void* rhs_address,
320 const PathTracingData* tracing_data);
321
322 template <typename T>
323 static bool CompareAttributeValues(const T& lhs, const T& rhs,
324 const PathTracingData* tracing_data);
325
326 /*! \brief Internal class pointer. */
327 Handler* handler_ = nullptr;
328 /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
329 const PathTracingData* tracing_data_ = nullptr;
330 /*! \brief Whether or not to map free vars. */
331 bool map_free_vars_ = false;
332};
333
334/*! \brief The default handler for equality testing.
335 *
336 * Users can derive from this class and override the DispatchSEqualReduce method,
337 * to customize equality testing.
338 */
339class SEqualHandlerDefault : public SEqualReducer::Handler {
340 public:
341 SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
342 bool defer_fails);
343 virtual ~SEqualHandlerDefault();
344
345 bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
346 const Optional<ObjectPathPair>& current_paths) override;
347 void DeferFail(const ObjectPathPair& mismatch_paths) override;
348 bool IsFailDeferralEnabled() override;
349 ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
350 void MarkGraphNode() override;
351
352 /*!
353 * \brief The entry point for equality testing
354 * \param lhs The left operand.
355 * \param rhs The right operand.
356 * \param map_free_vars Whether or not to remap variables if possible.
357 * \return The equality result.
358 */
359 virtual bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars);
360
361 protected:
362 /*!
363 * \brief The dispatcher for equality testing of intermediate objects
364 * \param lhs The left operand.
365 * \param rhs The right operand.
366 * \param map_free_vars Whether or not to remap variables if possible.
367 * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
368 * \return The equality result.
369 */
370 virtual bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
371 const Optional<ObjectPathPair>& current_paths);
372
373 private:
374 class Impl;
375 Impl* impl;
376};
377
378} // namespace tvm
379#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
380