1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file src/node/structural_hash.cc
21 */
#include <dmlc/memory_io.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/container/adt.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../support/base64.h"
#include "../support/str_escape.h"
#include "../support/utils.h"
#include "ndarray_hash_equal.h"
39
40namespace tvm {
41
// Define the dispatch function here since its primary user is in this file.
43void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const {
44 uint32_t tindex = self->type_index();
45 if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) {
46 LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey()
47 << " is not registered via TVM_REGISTER_NODE_TYPE";
48 }
49 fshash_reduce_[tindex](self, reducer);
50}
51
// Hash handler that handles free vars
// by assigning a unique counter in the order of their occurrence.
//
// This algorithm depends on the determinism of the traversal of the SHash function.
// In particular, when we traverse an unordered_map, we should first sort
// the entries by keys (or hashes of keys) before traversing.
58
59class SHashHandlerDefault::Impl {
60 public:
61 explicit Impl(SHashHandlerDefault* parent) : parent_(parent) {}
62
63 /*! \brief Pending reduce tasks. */
64 struct Task {
65 /*!
66 * \brief The object operand to be hashed.
67 * If the object is nullptr, then the reduced hash is already set
68 * the correct value.
69 */
70 ObjectRef object;
71 /*! \brief The partially reduce hash value.*/
72 size_t reduced_hash;
73 /*! \brief The expected location in the result stack. */
74 size_t result_stack_index = std::numeric_limits<size_t>::max();
75 /*! \brief Whether the children has been expanded via SEqualReduce */
76 bool children_expanded{false};
77 /*! \brief Whether the node is graph node. */
78 bool graph_node_hash{false};
79 /*! \brief whether to map the free variables. */
80 bool map_free_vars;
81
82 Task() = default;
83 explicit Task(ObjectRef object, size_t reduced_hash, bool map_free_vars)
84 : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {}
85 };
86
87 void MarkGraphNode() {
88 // need to push to pending tasks in this case
89 ICHECK(!allow_push_to_stack_ && !task_stack_.empty());
90 task_stack_.back().graph_node_hash = true;
91 }
92
93 bool LookupHashedValue(const ObjectRef& key, size_t* hash_value) {
94 auto it = hash_memo_.find(key);
95 if (it != hash_memo_.end()) {
96 hash_value[0] = it->second;
97 return true;
98 }
99 return false;
100 }
101
102 void SHashReduceHashedValue(size_t hashed_value) {
103 pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false));
104 }
105
106 void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) {
107 ICHECK(!hash_memo_.count(GetRef<ObjectRef>(var)));
108 if (map_free_vars) {
109 // use counter value.
110 size_t value = std::hash<size_t>()(free_var_counter_++);
111 pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
112 } else {
113 // use pointer hash
114 size_t value = std::hash<const runtime::Object*>()(var);
115 pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false));
116 }
117 }
118
119 void SHashReduce(const ObjectRef& object, bool map_free_vars) {
120 // Directly push the result
121 // Note: it is still important to push the result to pendng tasks
122 // so that the reduction order of hash values stays the same.
123 if (!object.defined()) {
124 pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false));
125 return;
126 }
127 auto it = hash_memo_.find(object);
128 if (it != hash_memo_.end()) {
129 pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false));
130 } else {
131 // Push a pending task with initial value.
132 pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars));
133 }
134 }
135
136 size_t Hash(const ObjectRef& object, bool map_free_vars) {
137 ICHECK_EQ(task_stack_.size(), 0U);
138 ICHECK_EQ(pending_tasks_.size(), 0U);
139 ICHECK_EQ(result_stack_.size(), 0U);
140
141 this->SHashReduce(object, map_free_vars);
142 ICHECK_EQ(pending_tasks_.size(), 1U);
143 ICHECK(allow_push_to_stack_);
144 task_stack_.emplace_back(std::move(pending_tasks_.back()));
145 pending_tasks_.clear();
146
147 this->RunTasks();
148
149 ICHECK_EQ(result_stack_.size(), 1U);
150 size_t ret = result_stack_.back();
151 result_stack_.pop_back();
152 return ret;
153 }
154
155 void DispatchSHash(const ObjectRef& object, bool map_free_vars) {
156 ICHECK(object.defined());
157 vtable_->SHashReduce(object.get(), SHashReducer(parent_, map_free_vars));
158 }
159
160 protected:
161 /*!
162 * \brief Pop the top entry of the task stack and push the hash into the result stack.
163 */
164 void PopTaskStack() {
165 const auto& entry = task_stack_.back();
166 result_stack_.push_back(entry.reduced_hash);
167 task_stack_.pop_back();
168 }
169 /*!
170 * \brief Compute the reduced hash value for the task.
171 * \param task The indicated task.
172 */
173 size_t ReduceHash(const Task& task) {
174 size_t stack_begin = task.result_stack_index;
175 ICHECK_LE(stack_begin, result_stack_.size());
176
177 // combine in the reverse order of the stack.
178 size_t reduced_hash = task.reduced_hash;
179 for (size_t i = result_stack_.size(); i != stack_begin; --i) {
180 reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]);
181 }
182 result_stack_.resize(stack_begin);
183 return reduced_hash;
184 }
185 // run the tasks.
186 void RunTasks() {
187 while (task_stack_.size() != 0) {
188 // Caution: entry becomes invalid when the stack changes
189 auto& entry = task_stack_.back();
190 if (entry.children_expanded) {
191 // reduce hash
192 entry.reduced_hash = ReduceHash(entry);
193 // When all the children has expanded and visited.
194 // entry.reduced_hash contains the reduced hash result.
195 auto it = hash_memo_.find(entry.object);
196 if (it != hash_memo_.end()) {
197 // use the pre-computed hash for the object.
198 entry.reduced_hash = it->second;
199 } else {
200 // Append the graph node counter to the hash
201 // so that we can distinguish DAG from trees.
202 if (entry.graph_node_hash) {
203 entry.reduced_hash = support::HashCombine(entry.reduced_hash,
204 std::hash<size_t>()(graph_node_counter_++));
205 }
206 hash_memo_[entry.object] = entry.reduced_hash;
207 }
208 // send value to parent.
209 this->PopTaskStack();
210 } else if (!entry.object.defined()) {
211 // Directly send value to parent
212 this->PopTaskStack();
213 } else {
214 // check if there are already hash for object.
215 auto it = hash_memo_.find(entry.object);
216 if (it != hash_memo_.end()) {
217 entry.reduced_hash = it->second;
218 this->PopTaskStack();
219 } else {
220 // NOTE: important to modify entry before visit.
221 // as entry becomes invalid after we change the stack.
222 entry.children_expanded = true;
223 entry.result_stack_index = result_stack_.size();
224
225 ICHECK_EQ(pending_tasks_.size(), 0U);
226 allow_push_to_stack_ = false;
227 // dispatch hash, reduce to the current slot.
228 parent_->DispatchSHash(entry.object, entry.map_free_vars);
229 allow_push_to_stack_ = true;
230 // Move pending tasks to the stack until the marked point.
231 while (pending_tasks_.size() != 0) {
232 task_stack_.emplace_back(std::move(pending_tasks_.back()));
233 pending_tasks_.pop_back();
234 }
235 }
236 }
237 }
238 }
239
240 private:
241 // The owner of this impl
242 SHashHandlerDefault* parent_;
243 // free var counter.
244 size_t free_var_counter_{0};
245 // graph node counter.
246 size_t graph_node_counter_{0};
247 // record current stack top
248 bool allow_push_to_stack_{true};
249 // list of pending tasks to be pushed to the stack.
250 std::vector<Task> pending_tasks_;
251 // Internal task stack to executed the task
252 std::vector<Task> task_stack_;
253 // Internal stack to store the result poped from the task stack.
254 std::vector<size_t> result_stack_;
255 // reflection vtable
256 ReflectionVTable* vtable_ = ReflectionVTable::Global();
257 // map from lhs to rhs
258 std::unordered_map<ObjectRef, size_t, ObjectPtrHash, ObjectPtrEqual> hash_memo_;
259};
260
261SHashHandlerDefault::SHashHandlerDefault() { impl = new Impl(this); }
262SHashHandlerDefault::~SHashHandlerDefault() { delete impl; }
263
264void SHashHandlerDefault::SHashReduceHashedValue(size_t hashed_value) {
265 return impl->SHashReduceHashedValue(hashed_value);
266}
267
268void SHashHandlerDefault::SHashReduce(const ObjectRef& key, bool map_free_vars) {
269 impl->SHashReduce(key, map_free_vars);
270}
271
272void SHashHandlerDefault::SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) {
273 impl->SHashReduceFreeVar(var, map_free_vars);
274}
275
276bool SHashHandlerDefault::LookupHashedValue(const ObjectRef& key, size_t* hashed_value) {
277 return impl->LookupHashedValue(key, hashed_value);
278}
279
280void SHashHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); }
281
282size_t SHashHandlerDefault::Hash(const ObjectRef& object, bool map_free_vars) {
283 return impl->Hash(object, map_free_vars);
284}
285
286void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars) {
287 impl->DispatchSHash(key, map_free_vars);
288}
289
290TVM_REGISTER_GLOBAL("node.StructuralHash")
291 .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t {
292 size_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars);
293 return static_cast<int64_t>(hashed_value);
294 });
295
296size_t StructuralHash::operator()(const ObjectRef& object) const {
297 return SHashHandlerDefault().Hash(object, false);
298}
299
// SHashReduce/SEqualReduce traits for runtime containers.
301struct StringObjTrait {
302 static constexpr const std::nullptr_t VisitAttrs = nullptr;
303
304 static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
305 hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size));
306 }
307
308 static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs,
309 SEqualReducer equal) {
310 if (lhs == rhs) return true;
311 if (lhs->size != rhs->size) return false;
312 if (lhs->data == rhs->data) return true;
313 return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
314 }
315};
316
317struct RefToObjectPtr : public ObjectRef {
318 static ObjectPtr<Object> Get(const ObjectRef& ref) { return GetDataPtr<Object>(ref); }
319};
320
321TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
322 .set_creator([](const std::string& bytes) {
323 return RefToObjectPtr::Get(runtime::String(bytes));
324 })
325 .set_repr_bytes([](const Object* n) -> std::string {
326 return GetRef<runtime::String>(static_cast<const runtime::StringObj*>(n))
327 .
328 operator std::string();
329 });
330
331TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
332 .set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
333 auto* op = static_cast<const runtime::StringObj*>(node.get());
334 p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
335 });
336
337struct ADTObjTrait {
338 static constexpr const std::nullptr_t VisitAttrs = nullptr;
339
340 static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) {
341 hash_reduce(key->tag);
342 hash_reduce(static_cast<uint64_t>(key->size));
343 for (uint32_t i = 0; i < key->size; ++i) {
344 hash_reduce((*key)[i]);
345 }
346 }
347
348 static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs,
349 SEqualReducer equal) {
350 if (lhs == rhs) return true;
351 if (lhs->tag != rhs->tag) return false;
352 if (lhs->size != rhs->size) return false;
353
354 for (uint32_t i = 0; i < lhs->size; ++i) {
355 if (!equal((*lhs)[i], (*rhs)[i])) return false;
356 }
357 return true;
358 }
359};
360
361TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
362
363void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
364 bool hash_data) {
365 ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
366 ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor";
367 (*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype));
368 (*hash_reduce)(arr->dl_tensor.ndim);
369 for (int i = 0; i < arr->dl_tensor.ndim; ++i) {
370 (*hash_reduce)(arr->dl_tensor.shape[i]);
371 }
372 if (hash_data) {
373 (*hash_reduce)
374 ->SHashReduceHashedValue(runtime::String::HashBytes(
375 static_cast<const char*>(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor)));
376 }
377}
378
379void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key,
380 SHashReducer hash_reduce) {
381 NDArrayHash(key, &hash_reduce, /*bool hash_data*/ true);
382}
383
384TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait)
385 .set_creator([](const std::string& blob) {
386 dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
387 support::Base64InStream b64strm(&mstrm);
388 b64strm.InitPosition();
389 runtime::NDArray temp;
390 ICHECK(temp.Load(&b64strm));
391 return RefToObjectPtr::Get(temp);
392 })
393 .set_repr_bytes([](const Object* n) -> std::string {
394 std::string blob;
395 dmlc::MemoryStringStream mstrm(&blob);
396 support::Base64OutStream b64strm(&mstrm);
397 const auto* ndarray = static_cast<const runtime::NDArray::Container*>(n);
398 runtime::SaveDLTensor(&b64strm, &ndarray->dl_tensor);
399 b64strm.Finish();
400 return blob;
401 });
402
403struct ArrayNodeTrait {
404 static constexpr const std::nullptr_t VisitAttrs = nullptr;
405
406 static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) {
407 hash_reduce(static_cast<uint64_t>(key->size()));
408 for (size_t i = 0; i < key->size(); ++i) {
409 hash_reduce(key->at(i));
410 }
411 }
412
413 static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) {
414 if (equal.IsPathTracingEnabled()) {
415 return SEqualReduceTraced(lhs, rhs, equal);
416 }
417
418 if (lhs->size() != rhs->size()) return false;
419 for (size_t i = 0; i < lhs->size(); ++i) {
420 if (!equal(lhs->at(i), rhs->at(i))) return false;
421 }
422 return true;
423 }
424
425 private:
426 static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs,
427 const SEqualReducer& equal) {
428 size_t min_size = std::min(lhs->size(), rhs->size());
429 const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths();
430
431 for (size_t index = 0; index < min_size; ++index) {
432 ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index),
433 array_paths->rhs_path->ArrayIndex(index)};
434 if (!equal(lhs->at(index), rhs->at(index), element_paths)) {
435 return false;
436 }
437 }
438
439 if (lhs->size() == rhs->size()) {
440 return true;
441 }
442
443 // If the array length is mismatched, don't report it immediately.
444 // Instead, defer the failure until we visit all children.
445 //
446 // This is for human readability. For example, say we have two sequences
447 //
448 // (1) a b c d e f g h i j k l m
449 // (2) a b c d e g h i j k l m
450 //
451 // If we directly report a mismatch at the end of the array right now,
452 // the user will see that array (1) has an element `m` at index 12 but array (2)
453 // has no index 12 because it's too short:
454 //
455 // (1) a b c d e f g h i j k l m
456 // ^error here
457 // (2) a b c d e g h i j k l m
458 // ^ error here
459 //
460 // This is not very helpful. Instead, if we defer reporting this mismatch until all elements
461 // are fully visited, we can be much more helpful with pointing out the location:
462 //
463 // (1) a b c d e f g h i j k l m
464 // ^
465 // error here
466 //
467 // (2) a b c d e g h i j k l m
468 // ^
469 // error here
470 if (equal->IsFailDeferralEnabled()) {
471 if (lhs->size() > min_size) {
472 equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
473 array_paths->rhs_path->MissingArrayElement(min_size)});
474 } else {
475 equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
476 array_paths->rhs_path->ArrayIndex(min_size)});
477 }
478 // Can return `true` pretending that everything is good since we have deferred the failure.
479 return true;
480 }
481 return false;
482 }
483};
484TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
485 .set_creator([](const std::string&) -> ObjectPtr<Object> {
486 return ::tvm::runtime::make_object<ArrayNode>();
487 });
488
489struct ShapeTupleObjTrait {
490 static constexpr const std::nullptr_t VisitAttrs = nullptr;
491
492 static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
493 hash_reduce(self->size);
494 for (size_t i = 0; i < self->size; ++i) {
495 hash_reduce(self->data[i]);
496 }
497 }
498
499 static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
500 SEqualReducer equal) {
501 if (lhs->size != rhs->size) return false;
502 for (size_t i = 0; i < lhs->size; ++i) {
503 if (!equal(lhs->data[i], rhs->data[i])) return false;
504 }
505 return true;
506 }
507};
508
509TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
510 .set_creator([](const std::string& blob) {
511 // Store shape tuple in blob to avoid large integer overflow in JSON.
512 dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
513 support::Base64InStream b64strm(&mstrm);
514 b64strm.InitPosition();
515 uint64_t size;
516 b64strm.Read<uint64_t>(&size);
517 std::vector<int64_t> data(size);
518 b64strm.ReadArray(data.data(), size);
519 ShapeTuple shape(data);
520 return RefToObjectPtr::Get(shape);
521 })
522 .set_repr_bytes([](const Object* n) -> std::string {
523 std::string blob;
524 dmlc::MemoryStringStream mstrm(&blob);
525 support::Base64OutStream b64strm(&mstrm);
526 const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
527 b64strm.Write<uint64_t>(shape->size);
528 b64strm.WriteArray(shape->data, shape->size);
529 b64strm.Finish();
530 return blob;
531 });
532
533struct MapNodeTrait {
534 static constexpr const std::nullptr_t VisitAttrs = nullptr;
535
536 static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) {
537 // SHash's var handling depends on the determinism of traversal.
538 // NOTE: only book-keep the mapped hash keys.
539 // This resolves common use cases where we want to store
540 // Map<Var, Value> where Var is defined in the function
541 // parameters.
542 using KV = std::pair<size_t, ObjectRef>;
543 std::vector<KV> temp;
544 for (const auto& kv : *key) {
545 size_t hashed_value;
546 if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) {
547 temp.emplace_back(hashed_value, kv.second);
548 }
549 }
550 // sort by the hash key of the keys.
551 std::sort(temp.begin(), temp.end(),
552 [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
553 // add size to the hash
554 hash_reduce(static_cast<uint64_t>(key->size()));
555 // hash the content
556 for (size_t i = 0; i < temp.size();) {
557 size_t k = i + 1;
558 for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {
559 }
560 // ties are rare, but we need to skip them to make the hash determinsitic
561 if (k == i + 1) {
562 hash_reduce->SHashReduceHashedValue(temp[i].first);
563 hash_reduce(temp[i].second);
564 }
565 i = k;
566 }
567 }
568
569 static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) {
570 // NOTE: only book-keep the mapped hash keys.
571 // This resolves common use cases where we want to store
572 // Map<Var, Value> where Var is defined in the function
573 // parameters.
574 using KV = std::pair<String, ObjectRef>;
575 std::vector<KV> temp;
576 for (const auto& kv : *key) {
577 temp.push_back(std::make_pair(Downcast<String>(kv.first), kv.second));
578 }
579 // sort by the hash key of the keys.
580 std::sort(temp.begin(), temp.end(),
581 [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
582 // NOTE: we won't have ties
583 // add size to the hash after sorting.
584 hash_reduce(static_cast<uint64_t>(key->size()));
585 // hash the content
586 for (size_t i = 0; i < temp.size(); ++i) {
587 hash_reduce(temp[i].first);
588 hash_reduce(temp[i].second);
589 }
590 }
591
592 static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
593 bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) {
594 return v.first->template IsInstance<StringObj>();
595 });
596 if (is_str_map) {
597 SHashReduceForSMap(key, hash_reduce);
598 } else {
599 SHashReduceForOMap(key, hash_reduce);
600 }
601 }
602
603 static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
604 for (const auto& kv : *lhs) {
605 // Only allow equal checking if the keys are already mapped
606 // This resolves common use cases where we want to store
607 // Map<Var, Value> where Var is defined in the function
608 // parameters.
609 ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
610 if (!rhs_key.defined()) return false;
611 auto it = rhs->find(rhs_key);
612 if (it == rhs->end()) return false;
613 if (!equal(kv.second, it->second)) return false;
614 }
615 return true;
616 }
617
618 static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
619 for (const auto& kv : *lhs) {
620 auto it = rhs->find(kv.first);
621 if (it == rhs->end()) return false;
622 if (!equal(kv.second, it->second)) return false;
623 }
624 return true;
625 }
626
627 static bool IsStringMap(const MapNode* map) {
628 return std::all_of(map->begin(), map->end(),
629 [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
630 }
631
632 static bool SEqualReduceTracedForOMap(const MapNode* lhs, const MapNode* rhs,
633 const SEqualReducer& equal) {
634 const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths();
635
636 std::vector<const Object*> seen_rhs_keys;
637
638 // First, check that every key from `lhs` is also in `rhs`,
639 // and their values are mapped to each other.
640 for (const auto& kv : *lhs) {
641 ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first);
642
643 ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
644 if (!rhs_key.defined()) {
645 equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
646 return false;
647 }
648
649 auto it = rhs->find(rhs_key);
650 if (it == rhs->end()) {
651 equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
652 return false;
653 }
654
655 if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) {
656 return false;
657 }
658
659 seen_rhs_keys.push_back(it->first.get());
660 }
661
662 std::sort(seen_rhs_keys.begin(), seen_rhs_keys.end());
663
664 // Second, check that we have visited every `rhs` key when iterating over `lhs`.
665 for (const auto& kv : *rhs) {
666 if (!std::binary_search(seen_rhs_keys.begin(), seen_rhs_keys.end(), kv.first.get())) {
667 equal.RecordMismatchPaths(
668 {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)});
669 return false;
670 }
671 }
672
673 ICHECK(lhs->size() == rhs->size());
674 return true;
675 }
676
677 static bool SEqualReduceTracedForSMap(const MapNode* lhs, const MapNode* rhs,
678 const SEqualReducer& equal) {
679 const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths();
680
681 // First, check that every key from `lhs` is also in `rhs`, and their values are equal.
682 for (const auto& kv : *lhs) {
683 ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first);
684 auto it = rhs->find(kv.first);
685 if (it == rhs->end()) {
686 equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
687 return false;
688 }
689
690 if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) {
691 return false;
692 }
693 }
694
695 // Second, make sure every key from `rhs` is also in `lhs`.
696 for (const auto& kv : *rhs) {
697 ObjectPath rhs_path = map_paths->rhs_path->MapValue(kv.first);
698 if (!lhs->count(kv.first)) {
699 equal.RecordMismatchPaths({map_paths->lhs_path->MissingMapEntry(), rhs_path});
700 return false;
701 }
702 }
703
704 ICHECK(lhs->size() == rhs->size());
705 return true;
706 }
707
708 static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs,
709 const SEqualReducer& equal) {
710 if (IsStringMap(lhs)) {
711 return SEqualReduceTracedForSMap(lhs, rhs, equal);
712 } else {
713 return SEqualReduceTracedForOMap(lhs, rhs, equal);
714 }
715 }
716
717 static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
718 if (equal.IsPathTracingEnabled()) {
719 return SEqualReduceTraced(lhs, rhs, equal);
720 }
721
722 if (rhs->size() != lhs->size()) return false;
723 if (rhs->size() == 0) return true;
724 bool ls = IsStringMap(lhs);
725 bool rs = IsStringMap(rhs);
726 if (ls != rs) {
727 return false;
728 }
729 return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal);
730 }
731};
732TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
733 .set_creator([](const std::string&) -> ObjectPtr<Object> { return MapNode::Empty(); });
734
735struct ReportNodeTrait {
736 static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) {
737 attrs->Visit("calls", &report->calls);
738 attrs->Visit("device_metrics", &report->device_metrics);
739 attrs->Visit("configuration", &report->configuration);
740 }
741 static constexpr std::nullptr_t SEqualReduce = nullptr;
742 static constexpr std::nullptr_t SHashReduce = nullptr;
743};
744TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::ReportNode, ReportNodeTrait);
745
746TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
747 .set_dispatch<runtime::profiling::ReportNode>([](const ObjectRef& node, ReprPrinter* p) {
748 auto* op = static_cast<const runtime::profiling::ReportNode*>(node.get());
749 p->stream << op->AsTable();
750 });
751
752struct CountNodeTrait {
753 static void VisitAttrs(runtime::profiling::CountNode* n, AttrVisitor* attrs) {
754 attrs->Visit("value", &n->value);
755 }
756 static constexpr std::nullptr_t SEqualReduce = nullptr;
757 static constexpr std::nullptr_t SHashReduce = nullptr;
758};
759TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::CountNode, CountNodeTrait);
760TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
761 .set_dispatch<runtime::profiling::CountNode>([](const ObjectRef& node, ReprPrinter* p) {
762 auto* op = static_cast<const runtime::profiling::CountNode*>(node.get());
763 p->stream << op->GetTypeKey() << "(" << op->value << ")";
764 });
765struct DurationNodeTrait {
766 static void VisitAttrs(runtime::profiling::DurationNode* n, AttrVisitor* attrs) {
767 attrs->Visit("microseconds", &n->microseconds);
768 }
769 static constexpr std::nullptr_t SEqualReduce = nullptr;
770 static constexpr std::nullptr_t SHashReduce = nullptr;
771};
772TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
773 .set_dispatch<runtime::profiling::DurationNode>([](const ObjectRef& node, ReprPrinter* p) {
774 auto* op = static_cast<const runtime::profiling::DurationNode*>(node.get());
775 p->stream << op->GetTypeKey() << "(" << op->microseconds << ")";
776 });
777TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::DurationNode, DurationNodeTrait);
778struct PercentNodeTrait {
779 static void VisitAttrs(runtime::profiling::PercentNode* n, AttrVisitor* attrs) {
780 attrs->Visit("percent", &n->percent);
781 }
782 static constexpr std::nullptr_t SEqualReduce = nullptr;
783 static constexpr std::nullptr_t SHashReduce = nullptr;
784};
785TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::PercentNode, PercentNodeTrait);
786TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
787 .set_dispatch<runtime::profiling::PercentNode>([](const ObjectRef& node, ReprPrinter* p) {
788 auto* op = static_cast<const runtime::profiling::PercentNode*>(node.get());
789 p->stream << op->GetTypeKey() << "(" << op->percent << ")";
790 });
791struct RatioNodeTrait {
792 static void VisitAttrs(runtime::profiling::RatioNode* n, AttrVisitor* attrs) {
793 attrs->Visit("ratio", &n->ratio);
794 }
795 static constexpr std::nullptr_t SEqualReduce = nullptr;
796 static constexpr std::nullptr_t SHashReduce = nullptr;
797};
798TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::RatioNode, RatioNodeTrait);
799TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
800 .set_dispatch<runtime::profiling::RatioNode>([](const ObjectRef& node, ReprPrinter* p) {
801 auto* op = static_cast<const runtime::profiling::RatioNode*>(node.get());
802 p->stream << op->GetTypeKey() << "(" << op->ratio << ")";
803 });
804
805} // namespace tvm
806