1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file src/node/structural_equal.cc |
21 | */ |
22 | #include <tvm/ir/module.h> |
23 | #include <tvm/node/functor.h> |
24 | #include <tvm/node/node.h> |
25 | #include <tvm/node/object_path.h> |
26 | #include <tvm/node/reflection.h> |
27 | #include <tvm/node/structural_equal.h> |
28 | #include <tvm/runtime/registry.h> |
29 | |
30 | #include <unordered_map> |
31 | |
32 | #include "ndarray_hash_equal.h" |
33 | |
34 | namespace tvm { |
35 | |
36 | TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); |
37 | |
38 | TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath" ) |
39 | .set_body_typed([](const ObjectPathPair& object_path_pair) { |
40 | return object_path_pair->lhs_path; |
41 | }); |
42 | |
43 | TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath" ) |
44 | .set_body_typed([](const ObjectPathPair& object_path_pair) { |
45 | return object_path_pair->rhs_path; |
46 | }); |
47 | |
48 | ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path) |
49 | : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {} |
50 | |
51 | ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) { |
52 | data_ = make_object<ObjectPathPairNode>(std::move(lhs_path), std::move(rhs_path)); |
53 | } |
54 | |
55 | // Define the dispatch function here since primary user is in this file. |
56 | bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, |
57 | SEqualReducer equal) const { |
58 | uint32_t tindex = self->type_index(); |
59 | if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) { |
60 | LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey() |
61 | << " is not registered via TVM_REGISTER_NODE_TYPE." |
62 | << " Did you forget to set _type_has_method_sequal_reduce=true?" ; |
63 | } |
64 | return fsequal_reduce_[tindex](self, other, equal); |
65 | } |
66 | |
67 | struct SEqualReducer::PathTracingData { |
68 | ObjectPathPair current_paths; |
69 | ObjectRef lhs_object; |
70 | ObjectRef rhs_object; |
71 | Optional<ObjectPathPair>* first_mismatch; |
72 | |
73 | ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { |
74 | Optional<String> lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs); |
75 | Optional<String> rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs); |
76 | return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key), |
77 | current_paths->rhs_path->Attr(rhs_attr_key)); |
78 | } |
79 | }; |
80 | |
81 | bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { |
82 | if (tracing_data_ == nullptr) { |
83 | // Fast path: no tracing |
84 | return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt); |
85 | } |
86 | return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr); |
87 | } |
88 | |
89 | bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { |
90 | if (tracing_data_ == nullptr) { |
91 | // Fast path: no tracing |
92 | return handler_->SEqualReduce(lhs, rhs, true, NullOpt); |
93 | } |
94 | return ObjectAttrsEqual(lhs, rhs, true, nullptr); |
95 | } |
96 | |
97 | /* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( |
98 | const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { |
99 | if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { |
100 | Optional<String> lhs_attr_key = |
101 | GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); |
102 | Optional<String> rhs_attr_key = |
103 | GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); |
104 | *tracing_data->first_mismatch = |
105 | ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key), |
106 | tracing_data->current_paths->rhs_path->Attr(rhs_attr_key)); |
107 | } |
108 | } |
109 | |
110 | template <typename T> |
111 | /* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs, |
112 | const PathTracingData* tracing_data) { |
113 | if (BaseValueEqual()(lhs, rhs)) { |
114 | return true; |
115 | } else { |
116 | GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data); |
117 | return false; |
118 | } |
119 | } |
120 | |
121 | bool SEqualReducer::operator()(const double& lhs, const double& rhs) const { |
122 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
123 | } |
124 | |
125 | bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const { |
126 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
127 | } |
128 | |
129 | bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const { |
130 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
131 | } |
132 | |
133 | bool SEqualReducer::operator()(const int& lhs, const int& rhs) const { |
134 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
135 | } |
136 | |
137 | bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const { |
138 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
139 | } |
140 | |
141 | bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const { |
142 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
143 | } |
144 | |
145 | bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const { |
146 | return CompareAttributeValues(lhs, rhs, tracing_data_); |
147 | } |
148 | |
149 | bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, |
150 | const void* rhs_address) const { |
151 | if (lhs == rhs) { |
152 | return true; |
153 | } else { |
154 | GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_); |
155 | return false; |
156 | } |
157 | } |
158 | |
159 | const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const { |
160 | ICHECK(tracing_data_ != nullptr) |
161 | << "GetCurrentObjectPaths() can only be called when path tracing is enabled" ; |
162 | return tracing_data_->current_paths; |
163 | } |
164 | |
165 | void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const { |
166 | ICHECK(tracing_data_ != nullptr) |
167 | << "RecordMismatchPaths() can only be called when path tracing is enabled" ; |
168 | if (!tracing_data_->first_mismatch->defined()) { |
169 | *tracing_data_->first_mismatch = paths; |
170 | } |
171 | } |
172 | |
173 | bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
174 | const ObjectPathPair* paths) const { |
175 | if (tracing_data_ == nullptr) { |
176 | // Fast path: no tracing |
177 | return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt); |
178 | } |
179 | |
180 | // Slow path: tracing object paths for better error reporting |
181 | |
182 | ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; |
183 | |
184 | if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { |
185 | return true; |
186 | } else { |
187 | if (!tracing_data_->first_mismatch->defined()) { |
188 | *tracing_data_->first_mismatch = new_paths; |
189 | } |
190 | return false; |
191 | } |
192 | } |
193 | |
194 | /*! |
195 | * \brief A non recursive stack based SEqual handler that can remaps vars. |
196 | * |
197 | * This handler pushs the Object equality cases into a stack, and |
198 | * traverses the stack to expand the necessary children that need to be checked. |
199 | * |
200 | * The order of SEqual being called is the same as the order as if we |
201 | * eagerly do recursive calls in SEqualReduce. |
202 | */ |
203 | class SEqualHandlerDefault::Impl { |
204 | public: |
205 | Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch, |
206 | bool defer_fails) |
207 | : parent_(parent), |
208 | assert_mode_(assert_mode), |
209 | first_mismatch_(first_mismatch), |
210 | defer_fails_(defer_fails) {} |
211 | |
212 | bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
213 | const Optional<ObjectPathPair>& current_paths) { |
214 | // We cannot use check lhs.same_as(rhs) to check equality. |
215 | // if we choose to enable var remapping. |
216 | // |
217 | // Counter example below (%x, %y) are shared vars |
218 | // between the two functions(possibly before/after rewriting). |
219 | // |
220 | // - function0: fn (%x, %y) { %x + %y } |
221 | // - function1. fn (%y, %x) { %x + %y } |
222 | // |
223 | // Because we choose to enable var remapping, |
224 | // %x is mapped to %y, and %y is mapped to %x, |
225 | // the body of the function no longer means the same thing. |
226 | // |
227 | // Take away: We can either choose only compare Var by address, |
228 | // in which case we can use same_as for quick checking, |
229 | // or we have to run deep comparison and avoid to use same_as checks. |
230 | auto run = [=]() { |
231 | if (!lhs.defined() && !rhs.defined()) return true; |
232 | if (!lhs.defined() && rhs.defined()) return false; |
233 | if (!rhs.defined() && lhs.defined()) return false; |
234 | if (lhs->type_index() != rhs->type_index()) return false; |
235 | auto it = equal_map_lhs_.find(lhs); |
236 | if (it != equal_map_lhs_.end()) { |
237 | return it->second.same_as(rhs); |
238 | } |
239 | if (equal_map_rhs_.count(rhs)) return false; |
240 | |
241 | // need to push to pending tasks in this case |
242 | pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths); |
243 | return true; |
244 | }; |
245 | return CheckResult(run(), lhs, rhs, current_paths); |
246 | } |
247 | |
248 | void DeferFail(const ObjectPathPair& mismatch_paths) { |
249 | pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths); |
250 | } |
251 | |
252 | bool IsFailDeferralEnabled() { return defer_fails_; } |
253 | |
254 | void MarkGraphNode() { |
255 | // need to push to pending tasks in this case |
256 | ICHECK(!allow_push_to_stack_ && !task_stack_.empty()); |
257 | task_stack_.back().graph_equal = true; |
258 | } |
259 | |
260 | ObjectRef MapLhsToRhs(const ObjectRef& lhs) { |
261 | auto it = equal_map_lhs_.find(lhs); |
262 | if (it != equal_map_lhs_.end()) return it->second; |
263 | return ObjectRef(nullptr); |
264 | } |
265 | |
266 | // Function that implements actual equality check. |
267 | bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { |
268 | if (!lhs.defined() && !rhs.defined()) return true; |
269 | task_stack_.clear(); |
270 | pending_tasks_.clear(); |
271 | equal_map_lhs_.clear(); |
272 | equal_map_rhs_.clear(); |
273 | root_lhs_ = lhs; |
274 | root_rhs_ = rhs; |
275 | |
276 | Optional<ObjectPathPair> current_paths; |
277 | if (IsPathTracingEnabled()) { |
278 | auto root_path = ObjectPath::Root(); |
279 | current_paths = ObjectPathPair(root_path, root_path); |
280 | } |
281 | if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) { |
282 | return false; |
283 | } |
284 | |
285 | ICHECK_EQ(pending_tasks_.size(), 1U); |
286 | ICHECK(allow_push_to_stack_); |
287 | task_stack_.emplace_back(std::move(pending_tasks_.back())); |
288 | pending_tasks_.clear(); |
289 | return RunTasks(); |
290 | } |
291 | |
292 | // The default equal as registered in the structural equal vtable. |
293 | bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, |
294 | const Optional<ObjectPathPair>& current_paths) { |
295 | auto compute = [=]() { |
296 | ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index()); |
297 | // skip entries that already have equality maps. |
298 | auto it = equal_map_lhs_.find(lhs); |
299 | if (it != equal_map_lhs_.end()) { |
300 | return it->second.same_as(rhs); |
301 | } |
302 | if (equal_map_rhs_.count(rhs)) return false; |
303 | |
304 | if (!IsPathTracingEnabled()) { |
305 | return vtable_->SEqualReduce(lhs.get(), rhs.get(), |
306 | SEqualReducer(parent_, nullptr, map_free_vars)); |
307 | } else { |
308 | PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; |
309 | return vtable_->SEqualReduce(lhs.get(), rhs.get(), |
310 | SEqualReducer(parent_, &tracing_data, map_free_vars)); |
311 | } |
312 | }; |
313 | return CheckResult(compute(), lhs, rhs, current_paths); |
314 | } |
315 | |
316 | protected: |
317 | // Check the result. |
318 | bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs, |
319 | const Optional<ObjectPathPair>& current_paths) { |
320 | if (IsPathTracingEnabled() && !result && !first_mismatch_->defined()) { |
321 | *first_mismatch_ = current_paths; |
322 | } |
323 | if (assert_mode_ && !result) { |
324 | std::ostringstream oss; |
325 | oss << "ValueError: StructuralEqual check failed, caused by lhs" ; |
326 | if (first_mismatch_->defined()) { |
327 | oss << " at " << first_mismatch_->value()->lhs_path; |
328 | if (root_lhs_.defined()) { |
329 | PrinterConfig cfg; |
330 | cfg->syntax_sugar = false; |
331 | cfg->path_to_underline.push_back(first_mismatch_->value()->lhs_path); |
332 | // The TVMScriptPrinter::Script will fallback to Repr printer, |
333 | // if the root node to print is not supported yet, |
334 | // e.g. Relay nodes, ArrayNode, MapNode, etc. |
335 | oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg); |
336 | } |
337 | } else { |
338 | oss << ":" << std::endl << lhs; |
339 | } |
340 | oss << std::endl << "and rhs" ; |
341 | if (first_mismatch_->defined()) { |
342 | oss << " at " << first_mismatch_->value()->rhs_path; |
343 | if (root_rhs_.defined()) { |
344 | PrinterConfig cfg; |
345 | cfg->syntax_sugar = false; |
346 | cfg->path_to_underline.push_back(first_mismatch_->value()->rhs_path); |
347 | // The TVMScriptPrinter::Script will fallback to Repr printer, |
348 | // if the root node to print is not supported yet, |
349 | // e.g. Relay nodes, ArrayNode, MapNode, etc. |
350 | oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg); |
351 | } |
352 | } else { |
353 | oss << ":" << std::endl << rhs; |
354 | } |
355 | LOG(FATAL) << oss.str(); |
356 | } |
357 | return result; |
358 | } |
359 | /*! |
360 | * \brief Run tasks until the stack reaches the stack begin |
361 | * \param stack_begin The expected beginning of the stack. |
362 | * \return The checks we encountered throughout the process. |
363 | */ |
364 | bool RunTasks() { |
365 | while (task_stack_.size() != 0) { |
366 | // Caution: entry becomes invalid when the stack changes |
367 | auto& entry = task_stack_.back(); |
368 | |
369 | if (entry.force_fail) { |
370 | if (IsPathTracingEnabled() && !first_mismatch_->defined()) { |
371 | *first_mismatch_ = entry.current_paths; |
372 | } |
373 | return false; |
374 | } |
375 | |
376 | if (entry.children_expanded) { |
377 | // When all the children has expanded and visited. |
378 | // This means all the condition checks for |
379 | // the current entry has been passed |
380 | // We can safely mark lhs and rhs as equal to each other. |
381 | auto it = equal_map_lhs_.find(entry.lhs); |
382 | if (it != equal_map_lhs_.end()) { |
383 | ICHECK(it->second.same_as(entry.rhs)); |
384 | } |
385 | // create the map if the quality is graph equal. |
386 | if (entry.graph_equal) { |
387 | equal_map_lhs_[entry.lhs] = entry.rhs; |
388 | equal_map_rhs_[entry.rhs] = entry.lhs; |
389 | } |
390 | task_stack_.pop_back(); |
391 | } else { |
392 | // mark before expand |
393 | // Important: because entry becomes invalid when stack changes. |
394 | entry.children_expanded = true; |
395 | // Expand the objects |
396 | // The SEqual of the object can call into this->SEqualReduce |
397 | // which populates the pending tasks. |
398 | ICHECK_EQ(pending_tasks_.size(), 0U); |
399 | allow_push_to_stack_ = false; |
400 | if (!parent_->DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars, |
401 | entry.current_paths)) |
402 | return false; |
403 | allow_push_to_stack_ = true; |
404 | // Push pending tasks in reverse order, so earlier tasks get to |
405 | // expand first in the stack |
406 | while (pending_tasks_.size() != 0) { |
407 | task_stack_.emplace_back(std::move(pending_tasks_.back())); |
408 | pending_tasks_.pop_back(); |
409 | } |
410 | } |
411 | } |
412 | return true; |
413 | } |
414 | |
415 | private: |
416 | /*! \brief Pending reduce tasks. */ |
417 | struct Task { |
418 | /*! \brief The lhs operand to be compared. */ |
419 | ObjectRef lhs; |
420 | /*! \brief The rhs operand to be compared. */ |
421 | ObjectRef rhs; |
422 | /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs` |
423 | * objects. */ |
424 | Optional<ObjectPathPair> current_paths; |
425 | /*! \brief The map free var argument. */ |
426 | bool map_free_vars; |
427 | /*! \brief Whether the children has been expanded via SEqualReduce */ |
428 | bool children_expanded{false}; |
429 | /*! \brief whether the task is about graph equality(need remap). */ |
430 | bool graph_equal{false}; |
431 | /*! \brief whether the task should return "false" without actually comparing anything */ |
432 | bool force_fail{false}; |
433 | |
434 | Task() = default; |
435 | Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, Optional<ObjectPathPair> current_paths) |
436 | : lhs(lhs), |
437 | rhs(rhs), |
438 | current_paths(std::move(current_paths)), |
439 | map_free_vars(map_free_vars) {} |
440 | |
441 | struct ForceFailTag {}; // dispatch tag for the constructor below |
442 | Task(ForceFailTag, const ObjectPathPair& current_paths) |
443 | : current_paths(current_paths), force_fail(true) {} |
444 | }; |
445 | |
446 | bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; } |
447 | |
448 | // The owner of this impl |
449 | SEqualHandlerDefault* parent_; |
450 | // list of pending tasks to be pushed to the stack. |
451 | std::vector<Task> pending_tasks_; |
452 | // Internal task stack to executed the task. |
453 | std::vector<Task> task_stack_; |
454 | // Whether we allow push to stack. |
455 | bool allow_push_to_stack_{true}; |
456 | // If in assert mode, must return true, and will throw error otherwise. |
457 | bool assert_mode_{false}; |
458 | // Location to store the paths to the first detected mismatch, or nullptr to disable path |
459 | // tracing. |
460 | Optional<ObjectPathPair>* first_mismatch_; |
461 | // reflection vtable |
462 | ReflectionVTable* vtable_ = ReflectionVTable::Global(); |
463 | // map from lhs to rhs |
464 | std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_; |
465 | // map from rhs to lhs |
466 | std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_; |
467 | // root lhs for result printing |
468 | Optional<ObjectRef> root_lhs_; |
469 | // root rhs for result printing |
470 | Optional<ObjectRef> root_rhs_; |
471 | // whether to defer fails |
472 | bool defer_fails_; |
473 | }; |
474 | |
475 | SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode, |
476 | Optional<ObjectPathPair>* first_mismatch, |
477 | bool defer_fails) { |
478 | impl = new Impl(this, assert_mode, first_mismatch, defer_fails); |
479 | } |
480 | |
481 | SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; } |
482 | |
483 | bool SEqualHandlerDefault::SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, |
484 | bool map_free_vars, |
485 | const Optional<ObjectPathPair>& current_paths) { |
486 | return impl->SEqualReduce(lhs, rhs, map_free_vars, current_paths); |
487 | } |
488 | |
489 | void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) { |
490 | impl->DeferFail(mismatch_paths); |
491 | } |
492 | |
493 | bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); } |
494 | |
495 | ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); } |
496 | |
497 | void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } |
498 | |
499 | bool SEqualHandlerDefault::Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { |
500 | return impl->Equal(lhs, rhs, map_free_vars); |
501 | } |
502 | |
503 | bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, |
504 | bool map_free_vars, |
505 | const Optional<ObjectPathPair>& current_paths) { |
506 | return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); |
507 | } |
508 | |
509 | TVM_REGISTER_GLOBAL("node.StructuralEqual" ) |
510 | .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, |
511 | bool map_free_vars) { |
512 | Optional<ObjectPathPair> first_mismatch; |
513 | return SEqualHandlerDefault(assert_mode, &first_mismatch, false) |
514 | .Equal(lhs, rhs, map_free_vars); |
515 | }); |
516 | |
517 | TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch" ) |
518 | .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { |
519 | Optional<ObjectPathPair> first_mismatch; |
520 | bool equal = |
521 | SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); |
522 | ICHECK(equal == !first_mismatch.defined()); |
523 | return first_mismatch; |
524 | }); |
525 | |
526 | bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { |
527 | return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false); |
528 | } |
529 | |
530 | bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, |
531 | SEqualReducer equal, bool compare_data) { |
532 | if (lhs == rhs) return true; |
533 | |
534 | auto ldt = lhs->dl_tensor.dtype; |
535 | auto rdt = rhs->dl_tensor.dtype; |
536 | ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor" ; |
537 | ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor" ; |
538 | ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor" ; |
539 | ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor" ; |
540 | |
541 | if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; |
542 | for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { |
543 | if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; |
544 | } |
545 | if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { |
546 | size_t data_size = runtime::GetDataSize(lhs->dl_tensor); |
547 | if (compare_data) { |
548 | return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; |
549 | } else { |
550 | return true; |
551 | } |
552 | } else { |
553 | return false; |
554 | } |
555 | } |
556 | |
557 | bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, |
558 | const runtime::NDArray::Container* rhs, |
559 | SEqualReducer equal) { |
560 | return NDArrayEqual(lhs, rhs, equal, true); |
561 | } |
562 | |
563 | } // namespace tvm |
564 | |