1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file src/node/structural_equal.cc
21 */
22#include <tvm/ir/module.h>
23#include <tvm/node/functor.h>
24#include <tvm/node/node.h>
25#include <tvm/node/object_path.h>
26#include <tvm/node/reflection.h>
27#include <tvm/node/structural_equal.h>
28#include <tvm/runtime/registry.h>
29
30#include <unordered_map>
31
32#include "ndarray_hash_equal.h"
33
34namespace tvm {
35
36TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode);
37
38TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath")
39 .set_body_typed([](const ObjectPathPair& object_path_pair) {
40 return object_path_pair->lhs_path;
41 });
42
43TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath")
44 .set_body_typed([](const ObjectPathPair& object_path_pair) {
45 return object_path_pair->rhs_path;
46 });
47
48ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
49 : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {}
50
51ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) {
52 data_ = make_object<ObjectPathPairNode>(std::move(lhs_path), std::move(rhs_path));
53}
54
55// Define the dispatch function here since primary user is in this file.
56bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
57 SEqualReducer equal) const {
58 uint32_t tindex = self->type_index();
59 if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) {
60 LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
61 << " is not registered via TVM_REGISTER_NODE_TYPE."
62 << " Did you forget to set _type_has_method_sequal_reduce=true?";
63 }
64 return fsequal_reduce_[tindex](self, other, equal);
65}
66
67struct SEqualReducer::PathTracingData {
68 ObjectPathPair current_paths;
69 ObjectRef lhs_object;
70 ObjectRef rhs_object;
71 Optional<ObjectPathPair>* first_mismatch;
72
73 ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const {
74 Optional<String> lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs);
75 Optional<String> rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs);
76 return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key),
77 current_paths->rhs_path->Attr(rhs_attr_key));
78 }
79};
80
81bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
82 if (tracing_data_ == nullptr) {
83 // Fast path: no tracing
84 return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt);
85 }
86 return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr);
87}
88
89bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
90 if (tracing_data_ == nullptr) {
91 // Fast path: no tracing
92 return handler_->SEqualReduce(lhs, rhs, true, NullOpt);
93 }
94 return ObjectAttrsEqual(lhs, rhs, true, nullptr);
95}
96
97/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch(
98 const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) {
99 if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) {
100 Optional<String> lhs_attr_key =
101 GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address);
102 Optional<String> rhs_attr_key =
103 GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address);
104 *tracing_data->first_mismatch =
105 ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key),
106 tracing_data->current_paths->rhs_path->Attr(rhs_attr_key));
107 }
108}
109
110template <typename T>
111/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
112 const PathTracingData* tracing_data) {
113 if (BaseValueEqual()(lhs, rhs)) {
114 return true;
115 } else {
116 GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
117 return false;
118 }
119}
120
// The scalar/string/dtype overloads below all delegate to
// CompareAttributeValues, which records the mismatch paths (when path
// tracing is enabled) before reporting inequality.

bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}

bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
  return CompareAttributeValues(lhs, rhs, tracing_data_);
}
148
149bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
150 const void* rhs_address) const {
151 if (lhs == rhs) {
152 return true;
153 } else {
154 GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
155 return false;
156 }
157}
158
159const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
160 ICHECK(tracing_data_ != nullptr)
161 << "GetCurrentObjectPaths() can only be called when path tracing is enabled";
162 return tracing_data_->current_paths;
163}
164
165void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const {
166 ICHECK(tracing_data_ != nullptr)
167 << "RecordMismatchPaths() can only be called when path tracing is enabled";
168 if (!tracing_data_->first_mismatch->defined()) {
169 *tracing_data_->first_mismatch = paths;
170 }
171}
172
173bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
174 const ObjectPathPair* paths) const {
175 if (tracing_data_ == nullptr) {
176 // Fast path: no tracing
177 return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt);
178 }
179
180 // Slow path: tracing object paths for better error reporting
181
182 ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths;
183
184 if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) {
185 return true;
186 } else {
187 if (!tracing_data_->first_mismatch->defined()) {
188 *tracing_data_->first_mismatch = new_paths;
189 }
190 return false;
191 }
192}
193
194/*!
195 * \brief A non recursive stack based SEqual handler that can remaps vars.
196 *
197 * This handler pushs the Object equality cases into a stack, and
198 * traverses the stack to expand the necessary children that need to be checked.
199 *
200 * The order of SEqual being called is the same as the order as if we
201 * eagerly do recursive calls in SEqualReduce.
202 */
class SEqualHandlerDefault::Impl {
 public:
  /*!
   * \brief Construct the implementation.
   * \param parent The owning handler, through which dispatch is re-entered.
   * \param assert_mode Whether a failed comparison raises LOG(FATAL) with a
   *        detailed report instead of returning false.
   * \param first_mismatch Out-slot that receives the paths to the first
   *        detected mismatch, or nullptr to disable path tracing.
   * \param defer_fails Whether fail results may be deferred (see DeferFail).
   */
  Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
       bool defer_fails)
      : parent_(parent),
        assert_mode_(assert_mode),
        first_mismatch_(first_mismatch),
        defer_fails_(defer_fails) {}

  /*!
   * \brief Reduce the equality condition of (lhs, rhs) by queueing a task.
   *
   * A true return only means "not yet proven unequal": the children of the
   * queued task are expanded and compared later by RunTasks.
   */
  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                    const Optional<ObjectPathPair>& current_paths) {
    // We cannot use lhs.same_as(rhs) as a shortcut for equality
    // if we choose to enable var remapping.
    //
    // Counter example below (%x, %y) are shared vars
    // between the two functions (possibly before/after rewriting).
    //
    // - function0: fn (%x, %y) { %x + %y }
    // - function1: fn (%y, %x) { %x + %y }
    //
    // Because we choose to enable var remapping,
    // %x is mapped to %y, and %y is mapped to %x,
    // the body of the function no longer means the same thing.
    //
    // Take away: We can either choose only compare Var by address,
    // in which case we can use same_as for quick checking,
    // or we have to run deep comparison and avoid using same_as checks.
    auto run = [=]() {
      // Both undefined => equal; exactly one undefined => unequal.
      if (!lhs.defined() && !rhs.defined()) return true;
      if (!lhs.defined() && rhs.defined()) return false;
      if (!rhs.defined() && lhs.defined()) return false;
      if (lhs->type_index() != rhs->type_index()) return false;
      // Reuse an already-established graph-node mapping when present.
      auto it = equal_map_lhs_.find(lhs);
      if (it != equal_map_lhs_.end()) {
        return it->second.same_as(rhs);
      }
      // rhs is already mapped to a different lhs => unequal.
      if (equal_map_rhs_.count(rhs)) return false;

      // need to push to pending tasks in this case
      pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths);
      return true;
    };
    return CheckResult(run(), lhs, rhs, current_paths);
  }

  /*! \brief Queue a task that unconditionally fails at the given paths. */
  void DeferFail(const ObjectPathPair& mismatch_paths) {
    pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
  }

  /*! \brief Whether failures may be deferred via DeferFail. */
  bool IsFailDeferralEnabled() { return defer_fails_; }

  /*! \brief Mark the task on top of the stack as a graph node (needs remap). */
  void MarkGraphNode() {
    // need to push to pending tasks in this case
    ICHECK(!allow_push_to_stack_ && !task_stack_.empty());
    task_stack_.back().graph_equal = true;
  }

  /*! \brief Look up the rhs mapped to `lhs`, or an undefined ref if absent. */
  ObjectRef MapLhsToRhs(const ObjectRef& lhs) {
    auto it = equal_map_lhs_.find(lhs);
    if (it != equal_map_lhs_.end()) return it->second;
    return ObjectRef(nullptr);
  }

  // Function that implements actual equality check.
  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
    if (!lhs.defined() && !rhs.defined()) return true;
    // Reset all state so the handler can be reused across calls.
    task_stack_.clear();
    pending_tasks_.clear();
    equal_map_lhs_.clear();
    equal_map_rhs_.clear();
    root_lhs_ = lhs;
    root_rhs_ = rhs;

    // With tracing enabled, both objects start at the root path.
    Optional<ObjectPathPair> current_paths;
    if (IsPathTracingEnabled()) {
      auto root_path = ObjectPath::Root();
      current_paths = ObjectPathPair(root_path, root_path);
    }
    if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) {
      return false;
    }

    // The root reduction must have queued exactly one task; seed the stack
    // with it and run to completion.
    ICHECK_EQ(pending_tasks_.size(), 1U);
    ICHECK(allow_push_to_stack_);
    task_stack_.emplace_back(std::move(pending_tasks_.back()));
    pending_tasks_.clear();
    return RunTasks();
  }

  // The default equal as registered in the structural equal vtable.
  bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                            const Optional<ObjectPathPair>& current_paths) {
    auto compute = [=]() {
      ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index());
      // skip entries that already have equality maps.
      auto it = equal_map_lhs_.find(lhs);
      if (it != equal_map_lhs_.end()) {
        return it->second.same_as(rhs);
      }
      if (equal_map_rhs_.count(rhs)) return false;

      if (!IsPathTracingEnabled()) {
        return vtable_->SEqualReduce(lhs.get(), rhs.get(),
                                     SEqualReducer(parent_, nullptr, map_free_vars));
      } else {
        // tracing_data only needs to outlive the SEqualReduce call below.
        PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_};
        return vtable_->SEqualReduce(lhs.get(), rhs.get(),
                                     SEqualReducer(parent_, &tracing_data, map_free_vars));
      }
    };
    return CheckResult(compute(), lhs, rhs, current_paths);
  }

 protected:
  // Check the result; in assert mode a failure raises a fatal error with a
  // printed report. With path tracing, also records `current_paths` as the
  // first mismatch when none has been recorded yet.
  bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs,
                   const Optional<ObjectPathPair>& current_paths) {
    if (IsPathTracingEnabled() && !result && !first_mismatch_->defined()) {
      *first_mismatch_ = current_paths;
    }
    if (assert_mode_ && !result) {
      std::ostringstream oss;
      oss << "ValueError: StructuralEqual check failed, caused by lhs";
      if (first_mismatch_->defined()) {
        oss << " at " << first_mismatch_->value()->lhs_path;
        if (root_lhs_.defined()) {
          PrinterConfig cfg;
          cfg->syntax_sugar = false;
          cfg->path_to_underline.push_back(first_mismatch_->value()->lhs_path);
          // The TVMScriptPrinter::Script will fallback to Repr printer,
          // if the root node to print is not supported yet,
          // e.g. Relay nodes, ArrayNode, MapNode, etc.
          oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg);
        }
      } else {
        oss << ":" << std::endl << lhs;
      }
      oss << std::endl << "and rhs";
      if (first_mismatch_->defined()) {
        oss << " at " << first_mismatch_->value()->rhs_path;
        if (root_rhs_.defined()) {
          PrinterConfig cfg;
          cfg->syntax_sugar = false;
          cfg->path_to_underline.push_back(first_mismatch_->value()->rhs_path);
          // The TVMScriptPrinter::Script will fallback to Repr printer,
          // if the root node to print is not supported yet,
          // e.g. Relay nodes, ArrayNode, MapNode, etc.
          oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg);
        }
      } else {
        oss << ":" << std::endl << rhs;
      }
      LOG(FATAL) << oss.str();
    }
    return result;
  }
  /*!
   * \brief Run tasks until the task stack is empty.
   * \return Whether all the equality checks encountered during the
   *         traversal passed.
   */
  bool RunTasks() {
    while (task_stack_.size() != 0) {
      // Caution: entry becomes invalid when the stack changes
      auto& entry = task_stack_.back();

      if (entry.force_fail) {
        // A deferred failure: record its paths (if tracing) and stop.
        if (IsPathTracingEnabled() && !first_mismatch_->defined()) {
          *first_mismatch_ = entry.current_paths;
        }
        return false;
      }

      if (entry.children_expanded) {
        // When all the children have been expanded and visited.
        // This means all the condition checks for
        // the current entry have been passed
        // We can safely mark lhs and rhs as equal to each other.
        auto it = equal_map_lhs_.find(entry.lhs);
        if (it != equal_map_lhs_.end()) {
          ICHECK(it->second.same_as(entry.rhs));
        }
        // create the map if the equality is graph-level equality (remap).
        if (entry.graph_equal) {
          equal_map_lhs_[entry.lhs] = entry.rhs;
          equal_map_rhs_[entry.rhs] = entry.lhs;
        }
        task_stack_.pop_back();
      } else {
        // mark before expand
        // Important: because entry becomes invalid when stack changes.
        entry.children_expanded = true;
        // Expand the objects
        // The SEqual of the object can call into this->SEqualReduce
        // which populates the pending tasks.
        ICHECK_EQ(pending_tasks_.size(), 0U);
        allow_push_to_stack_ = false;
        if (!parent_->DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars,
                                           entry.current_paths))
          return false;
        allow_push_to_stack_ = true;
        // Push pending tasks in reverse order, so earlier tasks get to
        // expand first in the stack
        while (pending_tasks_.size() != 0) {
          task_stack_.emplace_back(std::move(pending_tasks_.back()));
          pending_tasks_.pop_back();
        }
      }
    }
    return true;
  }

 private:
  /*! \brief Pending reduce tasks. */
  struct Task {
    /*! \brief The lhs operand to be compared. */
    ObjectRef lhs;
    /*! \brief The rhs operand to be compared. */
    ObjectRef rhs;
    /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs`
     * objects. */
    Optional<ObjectPathPair> current_paths;
    /*! \brief The map free var argument. */
    bool map_free_vars;
    /*! \brief Whether the children have been expanded via SEqualReduce */
    bool children_expanded{false};
    /*! \brief whether the task is about graph equality (needs remap). */
    bool graph_equal{false};
    /*! \brief whether the task should return "false" without actually comparing anything */
    bool force_fail{false};

    Task() = default;
    Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, Optional<ObjectPathPair> current_paths)
        : lhs(lhs),
          rhs(rhs),
          current_paths(std::move(current_paths)),
          map_free_vars(map_free_vars) {}

    struct ForceFailTag {};  // dispatch tag for the constructor below
    Task(ForceFailTag, const ObjectPathPair& current_paths)
        : current_paths(current_paths), force_fail(true) {}
  };

  // Path tracing is on exactly when a mismatch out-slot was supplied.
  bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; }

  // The owner of this impl
  SEqualHandlerDefault* parent_;
  // list of pending tasks to be pushed to the stack.
  std::vector<Task> pending_tasks_;
  // Internal task stack to execute the tasks.
  std::vector<Task> task_stack_;
  // Whether we allow push to stack.
  bool allow_push_to_stack_{true};
  // If in assert mode, must return true, and will throw error otherwise.
  bool assert_mode_{false};
  // Location to store the paths to the first detected mismatch, or nullptr to disable path
  // tracing.
  Optional<ObjectPathPair>* first_mismatch_;
  // reflection vtable
  ReflectionVTable* vtable_ = ReflectionVTable::Global();
  // map from lhs to rhs
  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
  // map from rhs to lhs
  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_;
  // root lhs for result printing
  Optional<ObjectRef> root_lhs_;
  // root rhs for result printing
  Optional<ObjectRef> root_rhs_;
  // whether to defer fails
  bool defer_fails_;
};
474
// --- SEqualHandlerDefault: pimpl wrapper; all methods forward to Impl ---

// Constructor allocates the implementation; see Impl for parameter semantics.
// NOTE(review): `impl` is a raw owning pointer (declared in the header) and
// is released in the destructor below.
SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode,
                                           Optional<ObjectPathPair>* first_mismatch,
                                           bool defer_fails) {
  impl = new Impl(this, assert_mode, first_mismatch, defer_fails);
}

SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; }

// Reduce the equality of (lhs, rhs) under the given paths.
bool SEqualHandlerDefault::SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs,
                                        bool map_free_vars,
                                        const Optional<ObjectPathPair>& current_paths) {
  return impl->SEqualReduce(lhs, rhs, map_free_vars, current_paths);
}

// Queue an unconditional failure at the given mismatch paths.
void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) {
  impl->DeferFail(mismatch_paths);
}

bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); }

ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); }

void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); }

// Run the full (non-recursive) structural equality check.
bool SEqualHandlerDefault::Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
  return impl->Equal(lhs, rhs, map_free_vars);
}

// Dispatch through the reflection vtable for the concrete node type.
bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs,
                                                bool map_free_vars,
                                                const Optional<ObjectPathPair>& current_paths) {
  return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths);
}
508
509TVM_REGISTER_GLOBAL("node.StructuralEqual")
510 .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
511 bool map_free_vars) {
512 Optional<ObjectPathPair> first_mismatch;
513 return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
514 .Equal(lhs, rhs, map_free_vars);
515 });
516
517TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
518 .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
519 Optional<ObjectPathPair> first_mismatch;
520 bool equal =
521 SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars);
522 ICHECK(equal == !first_mismatch.defined());
523 return first_mismatch;
524 });
525
526bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
527 return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false);
528}
529
530bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
531 SEqualReducer equal, bool compare_data) {
532 if (lhs == rhs) return true;
533
534 auto ldt = lhs->dl_tensor.dtype;
535 auto rdt = rhs->dl_tensor.dtype;
536 ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
537 ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
538 ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor";
539 ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor";
540
541 if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
542 for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
543 if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false;
544 }
545 if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
546 size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
547 if (compare_data) {
548 return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
549 } else {
550 return true;
551 }
552 } else {
553 return false;
554 }
555}
556
// Vtable hook for NDArray structural equality: always compares the raw
// tensor data in addition to shape and dtype.
bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
                                         const runtime::NDArray::Container* rhs,
                                         SEqualReducer equal) {
  return NDArrayEqual(lhs, rhs, equal, true);
}
562
563} // namespace tvm
564