1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file src/node/structural_hash.cc |
21 | */ |
22 | #include <dmlc/memory_io.h> |
23 | #include <tvm/node/functor.h> |
24 | #include <tvm/node/node.h> |
25 | #include <tvm/node/object_path.h> |
26 | #include <tvm/node/reflection.h> |
27 | #include <tvm/node/structural_hash.h> |
28 | #include <tvm/runtime/container/adt.h> |
29 | #include <tvm/runtime/profiling.h> |
30 | #include <tvm/runtime/registry.h> |
31 | |
32 | #include <algorithm> |
33 | #include <unordered_map> |
34 | |
35 | #include "../support/base64.h" |
36 | #include "../support/str_escape.h" |
37 | #include "../support/utils.h" |
38 | #include "ndarray_hash_equal.h" |
39 | |
40 | namespace tvm { |
41 | |
42 | // Define the dispatch function here since the primary user is in this file. |
43 | void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const { |
44 | uint32_t tindex = self->type_index(); |
45 | if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) { |
46 | LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey() |
47 | << " is not registered via TVM_REGISTER_NODE_TYPE"; |
48 | } |
49 | fshash_reduce_[tindex](self, reducer); |
50 | } |
51 | |
52 | // Hash handler that handles free vars |
53 | // by assigning a unique counter in the order of their occurrence. |
54 | // |
55 | // This algorithm depends on the determinism of the traversal of the SHash function. |
56 | // In particular, when we traverse an unordered_map, we should first sort |
57 | // the entries by keys (or hashes of keys) before traversing. |
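// Illustrative behavior sketch (assuming Var-like nodes that route through
// SHashReduceFreeVar): with map_free_vars = true, expressions that differ only
// in the identity of otherwise-undefined free variables, e.g. `x + 1` versus
// `y + 1`, reduce to the same hash because each free var is replaced by its
// occurrence counter; with map_free_vars = false such vars are hashed by
// pointer identity instead (see SHashReduceFreeVar below).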
58 | |
59 | class SHashHandlerDefault::Impl { |
60 | public: |
61 | explicit Impl(SHashHandlerDefault* parent) : parent_(parent) {} |
62 | |
63 | /*! \brief Pending reduce tasks. */ |
64 | struct Task { |
65 | /*! |
66 | * \brief The object operand to be hashed. |
67 | * If the object is nullptr, then the reduced hash is already set to |
68 | * the correct value. |
69 | */ |
70 | ObjectRef object; |
71 | /*! \brief The partially reduced hash value. */ |
72 | size_t reduced_hash; |
73 | /*! \brief The expected location in the result stack. */ |
74 | size_t result_stack_index = std::numeric_limits<size_t>::max(); |
75 | /*! \brief Whether the children have been expanded via SHashReduce. */ |
76 | bool children_expanded{false}; |
77 | /*! \brief Whether the node is a graph node. */ |
78 | bool graph_node_hash{false}; |
79 | /*! \brief whether to map the free variables. */ |
80 | bool map_free_vars; |
81 | |
82 | Task() = default; |
83 | explicit Task(ObjectRef object, size_t reduced_hash, bool map_free_vars) |
84 | : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {} |
85 | }; |
86 | |
87 | void MarkGraphNode() { |
88 | // Only valid while a node is being dispatched, i.e. the marked node is on top of the task stack. |
89 | ICHECK(!allow_push_to_stack_ && !task_stack_.empty()); |
90 | task_stack_.back().graph_node_hash = true; |
91 | } |
92 | |
93 | bool LookupHashedValue(const ObjectRef& key, size_t* hash_value) { |
94 | auto it = hash_memo_.find(key); |
95 | if (it != hash_memo_.end()) { |
96 | hash_value[0] = it->second; |
97 | return true; |
98 | } |
99 | return false; |
100 | } |
101 | |
102 | void SHashReduceHashedValue(size_t hashed_value) { |
103 | pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false)); |
104 | } |
105 | |
106 | void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) { |
107 | ICHECK(!hash_memo_.count(GetRef<ObjectRef>(var))); |
108 | if (map_free_vars) { |
109 | // use counter value. |
110 | size_t value = std::hash<size_t>()(free_var_counter_++); |
111 | pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); |
112 | } else { |
113 | // use pointer hash |
114 | size_t value = std::hash<const runtime::Object*>()(var); |
115 | pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); |
116 | } |
117 | } |
118 | |
119 | void SHashReduce(const ObjectRef& object, bool map_free_vars) { |
120 | // Directly push the result |
121 | // Note: it is still important to push the result to pending tasks |
122 | // so that the reduction order of hash values stays the same. |
123 | if (!object.defined()) { |
124 | pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false)); |
125 | return; |
126 | } |
127 | auto it = hash_memo_.find(object); |
128 | if (it != hash_memo_.end()) { |
129 | pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false)); |
130 | } else { |
131 | // Push a pending task with initial value. |
132 | pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars)); |
133 | } |
134 | } |
135 | |
136 | size_t Hash(const ObjectRef& object, bool map_free_vars) { |
137 | ICHECK_EQ(task_stack_.size(), 0U); |
138 | ICHECK_EQ(pending_tasks_.size(), 0U); |
139 | ICHECK_EQ(result_stack_.size(), 0U); |
140 | |
141 | this->SHashReduce(object, map_free_vars); |
142 | ICHECK_EQ(pending_tasks_.size(), 1U); |
143 | ICHECK(allow_push_to_stack_); |
144 | task_stack_.emplace_back(std::move(pending_tasks_.back())); |
145 | pending_tasks_.clear(); |
146 | |
147 | this->RunTasks(); |
148 | |
149 | ICHECK_EQ(result_stack_.size(), 1U); |
150 | size_t ret = result_stack_.back(); |
151 | result_stack_.pop_back(); |
152 | return ret; |
153 | } |
154 | |
155 | void DispatchSHash(const ObjectRef& object, bool map_free_vars) { |
156 | ICHECK(object.defined()); |
157 | vtable_->SHashReduce(object.get(), SHashReducer(parent_, map_free_vars)); |
158 | } |
159 | |
160 | protected: |
161 | /*! |
162 | * \brief Pop the top entry of the task stack and push the hash into the result stack. |
163 | */ |
164 | void PopTaskStack() { |
165 | const auto& entry = task_stack_.back(); |
166 | result_stack_.push_back(entry.reduced_hash); |
167 | task_stack_.pop_back(); |
168 | } |
169 | /*! |
170 | * \brief Compute the reduced hash value for the task. |
171 | * \param task The indicated task. |
172 | */ |
173 | size_t ReduceHash(const Task& task) { |
174 | size_t stack_begin = task.result_stack_index; |
175 | ICHECK_LE(stack_begin, result_stack_.size()); |
176 | |
177 | // combine in the reverse order of the stack. |
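// E.g. if children hashes h1 then h2 were pushed onto the result stack for
// this task, the combined value is HashCombine(HashCombine(reduced_hash, h2), h1).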
178 | size_t reduced_hash = task.reduced_hash; |
179 | for (size_t i = result_stack_.size(); i != stack_begin; --i) { |
180 | reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]); |
181 | } |
182 | result_stack_.resize(stack_begin); |
183 | return reduced_hash; |
184 | } |
185 | // run the tasks. |
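// RunTasks performs an explicit-stack post-order traversal: each defined,
// not-yet-memoized object is visited twice, first to expand its children
// (children_expanded == false) and again, once the children hashes are on the
// result stack, to combine them into its own hash (children_expanded == true).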
186 | void RunTasks() { |
187 | while (task_stack_.size() != 0) { |
188 | // Caution: entry becomes invalid when the stack changes |
189 | auto& entry = task_stack_.back(); |
190 | if (entry.children_expanded) { |
191 | // reduce hash |
192 | entry.reduced_hash = ReduceHash(entry); |
193 | // When all the children have been expanded and visited, |
194 | // entry.reduced_hash contains the reduced hash result. |
195 | auto it = hash_memo_.find(entry.object); |
196 | if (it != hash_memo_.end()) { |
197 | // use the pre-computed hash for the object. |
198 | entry.reduced_hash = it->second; |
199 | } else { |
200 | // Append the graph node counter to the hash |
201 | // so that we can distinguish DAGs from trees. |
202 | if (entry.graph_node_hash) { |
203 | entry.reduced_hash = support::HashCombine(entry.reduced_hash, |
204 | std::hash<size_t>()(graph_node_counter_++)); |
205 | } |
206 | hash_memo_[entry.object] = entry.reduced_hash; |
207 | } |
208 | // send value to parent. |
209 | this->PopTaskStack(); |
210 | } else if (!entry.object.defined()) { |
211 | // Directly send value to parent |
212 | this->PopTaskStack(); |
213 | } else { |
214 | // check if there is already a hash for the object. |
215 | auto it = hash_memo_.find(entry.object); |
216 | if (it != hash_memo_.end()) { |
217 | entry.reduced_hash = it->second; |
218 | this->PopTaskStack(); |
219 | } else { |
220 | // NOTE: it is important to modify the entry before visiting children, |
221 | // as the entry reference becomes invalid once the stack changes. |
222 | entry.children_expanded = true; |
223 | entry.result_stack_index = result_stack_.size(); |
224 | |
225 | ICHECK_EQ(pending_tasks_.size(), 0U); |
226 | allow_push_to_stack_ = false; |
227 | // dispatch hash, reduce to the current slot. |
228 | parent_->DispatchSHash(entry.object, entry.map_free_vars); |
229 | allow_push_to_stack_ = true; |
230 | // Move pending tasks to the stack until the marked point. |
231 | while (pending_tasks_.size() != 0) { |
232 | task_stack_.emplace_back(std::move(pending_tasks_.back())); |
233 | pending_tasks_.pop_back(); |
234 | } |
235 | } |
236 | } |
237 | } |
238 | } |
239 | |
240 | private: |
241 | // The owner of this impl |
242 | SHashHandlerDefault* parent_; |
243 | // free var counter. |
244 | size_t free_var_counter_{0}; |
245 | // graph node counter. |
246 | size_t graph_node_counter_{0}; |
247 | // whether it is allowed to push new tasks onto the task stack. |
248 | bool allow_push_to_stack_{true}; |
249 | // list of pending tasks to be pushed to the stack. |
250 | std::vector<Task> pending_tasks_; |
251 | // Internal task stack to execute the tasks. |
252 | std::vector<Task> task_stack_; |
253 | // Internal stack to store the results popped from the task stack. |
254 | std::vector<size_t> result_stack_; |
255 | // reflection vtable |
256 | ReflectionVTable* vtable_ = ReflectionVTable::Global(); |
257 | // memo table that maps each visited object to its computed hash value. |
258 | std::unordered_map<ObjectRef, size_t, ObjectPtrHash, ObjectPtrEqual> hash_memo_; |
259 | }; |
260 | |
261 | SHashHandlerDefault::SHashHandlerDefault() { impl = new Impl(this); } |
262 | SHashHandlerDefault::~SHashHandlerDefault() { delete impl; } |
263 | |
264 | void SHashHandlerDefault::SHashReduceHashedValue(size_t hashed_value) { |
265 | return impl->SHashReduceHashedValue(hashed_value); |
266 | } |
267 | |
268 | void SHashHandlerDefault::SHashReduce(const ObjectRef& key, bool map_free_vars) { |
269 | impl->SHashReduce(key, map_free_vars); |
270 | } |
271 | |
272 | void SHashHandlerDefault::SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) { |
273 | impl->SHashReduceFreeVar(var, map_free_vars); |
274 | } |
275 | |
276 | bool SHashHandlerDefault::LookupHashedValue(const ObjectRef& key, size_t* hashed_value) { |
277 | return impl->LookupHashedValue(key, hashed_value); |
278 | } |
279 | |
280 | void SHashHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } |
281 | |
282 | size_t SHashHandlerDefault::Hash(const ObjectRef& object, bool map_free_vars) { |
283 | return impl->Hash(object, map_free_vars); |
284 | } |
285 | |
286 | void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars) { |
287 | impl->DispatchSHash(key, map_free_vars); |
288 | } |
289 | |
290 | TVM_REGISTER_GLOBAL("node.StructuralHash") |
291 | .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { |
292 | size_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); |
293 | return static_cast<int64_t>(hashed_value); |
294 | }); |
295 | |
296 | size_t StructuralHash::operator()(const ObjectRef& object) const { |
297 | return SHashHandlerDefault().Hash(object, false); |
298 | } |
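// Illustrative usage (a sketch, not part of this file's tested surface):
//
//   size_t h = tvm::StructuralHash()(some_object_ref);  // map_free_vars = false
//
// The "node.StructuralHash" global registered above exposes the same
// computation, with an explicit map_free_vars flag, to the FFI; the Python
// helper tvm.ir.structural_hash is assumed to route through it.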
299 | |
300 | // SEqualReduce and SHashReduce traits for runtime containers. |
301 | struct StringObjTrait { |
302 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
303 | |
304 | static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { |
305 | hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); |
306 | } |
307 | |
308 | static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, |
309 | SEqualReducer equal) { |
310 | if (lhs == rhs) return true; |
311 | if (lhs->size != rhs->size) return false; |
312 | if (lhs->data == rhs->data) return true; |
313 | return std::memcmp(lhs->data, rhs->data, lhs->size) == 0; |
314 | } |
315 | }; |
316 | |
317 | struct RefToObjectPtr : public ObjectRef { |
318 | static ObjectPtr<Object> Get(const ObjectRef& ref) { return GetDataPtr<Object>(ref); } |
319 | }; |
320 | |
321 | TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) |
322 | .set_creator([](const std::string& bytes) { |
323 | return RefToObjectPtr::Get(runtime::String(bytes)); |
324 | }) |
325 | .set_repr_bytes([](const Object* n) -> std::string { |
326 | return GetRef<runtime::String>(static_cast<const runtime::StringObj*>(n)) |
327 | . |
328 | operator std::string(); |
329 | }); |
330 | |
331 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
332 | .set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) { |
333 | auto* op = static_cast<const runtime::StringObj*>(node.get()); |
334 | p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; |
335 | }); |
336 | |
337 | struct ADTObjTrait { |
338 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
339 | |
340 | static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { |
341 | hash_reduce(key->tag); |
342 | hash_reduce(static_cast<uint64_t>(key->size)); |
343 | for (uint32_t i = 0; i < key->size; ++i) { |
344 | hash_reduce((*key)[i]); |
345 | } |
346 | } |
347 | |
348 | static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, |
349 | SEqualReducer equal) { |
350 | if (lhs == rhs) return true; |
351 | if (lhs->tag != rhs->tag) return false; |
352 | if (lhs->size != rhs->size) return false; |
353 | |
354 | for (uint32_t i = 0; i < lhs->size; ++i) { |
355 | if (!equal((*lhs)[i], (*rhs)[i])) return false; |
356 | } |
357 | return true; |
358 | } |
359 | }; |
360 | |
361 | TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); |
362 | |
363 | void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, |
364 | bool hash_data) { |
365 | ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "Can only hash CPU tensor"; |
366 | ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor"; |
367 | (*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype)); |
368 | (*hash_reduce)(arr->dl_tensor.ndim); |
369 | for (int i = 0; i < arr->dl_tensor.ndim; ++i) { |
370 | (*hash_reduce)(arr->dl_tensor.shape[i]); |
371 | } |
372 | if (hash_data) { |
373 | (*hash_reduce) |
374 | ->SHashReduceHashedValue(runtime::String::HashBytes( |
375 | static_cast<const char*>(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); |
376 | } |
377 | } |
378 | |
379 | void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, |
380 | SHashReducer hash_reduce) { |
381 | NDArrayHash(key, &hash_reduce, /*hash_data=*/true); |
382 | } |
383 | |
384 | TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) |
385 | .set_creator([](const std::string& blob) { |
386 | dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob)); |
387 | support::Base64InStream b64strm(&mstrm); |
388 | b64strm.InitPosition(); |
389 | runtime::NDArray temp; |
390 | ICHECK(temp.Load(&b64strm)); |
391 | return RefToObjectPtr::Get(temp); |
392 | }) |
393 | .set_repr_bytes([](const Object* n) -> std::string { |
394 | std::string blob; |
395 | dmlc::MemoryStringStream mstrm(&blob); |
396 | support::Base64OutStream b64strm(&mstrm); |
397 | const auto* ndarray = static_cast<const runtime::NDArray::Container*>(n); |
398 | runtime::SaveDLTensor(&b64strm, &ndarray->dl_tensor); |
399 | b64strm.Finish(); |
400 | return blob; |
401 | }); |
402 | |
403 | struct ArrayNodeTrait { |
404 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
405 | |
406 | static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { |
407 | hash_reduce(static_cast<uint64_t>(key->size())); |
408 | for (size_t i = 0; i < key->size(); ++i) { |
409 | hash_reduce(key->at(i)); |
410 | } |
411 | } |
412 | |
413 | static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { |
414 | if (equal.IsPathTracingEnabled()) { |
415 | return SEqualReduceTraced(lhs, rhs, equal); |
416 | } |
417 | |
418 | if (lhs->size() != rhs->size()) return false; |
419 | for (size_t i = 0; i < lhs->size(); ++i) { |
420 | if (!equal(lhs->at(i), rhs->at(i))) return false; |
421 | } |
422 | return true; |
423 | } |
424 | |
425 | private: |
426 | static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs, |
427 | const SEqualReducer& equal) { |
428 | size_t min_size = std::min(lhs->size(), rhs->size()); |
429 | const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); |
430 | |
431 | for (size_t index = 0; index < min_size; ++index) { |
432 | ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index), |
433 | array_paths->rhs_path->ArrayIndex(index)}; |
434 | if (!equal(lhs->at(index), rhs->at(index), element_paths)) { |
435 | return false; |
436 | } |
437 | } |
438 | |
439 | if (lhs->size() == rhs->size()) { |
440 | return true; |
441 | } |
442 | |
443 | // If the array length is mismatched, don't report it immediately. |
444 | // Instead, defer the failure until we visit all children. |
445 | // |
446 | // This is for human readability. For example, say we have two sequences |
447 | // |
448 | // (1) a b c d e f g h i j k l m |
449 | // (2) a b c d e g h i j k l m |
450 | // |
451 | // If we directly report a mismatch at the end of the array right now, |
452 | // the user will see that array (1) has an element `m` at index 12 but array (2) |
453 | // has no index 12 because it's too short: |
454 | // |
455 | // (1) a b c d e f g h i j k l m |
456 | // ^error here |
457 | // (2) a b c d e g h i j k l m |
458 | // ^ error here |
459 | // |
460 | // This is not very helpful. Instead, if we defer reporting this mismatch until all elements |
461 | // are fully visited, we can be much more helpful with pointing out the location: |
462 | // |
463 | // (1) a b c d e f g h i j k l m |
464 | // ^ |
465 | // error here |
466 | // |
467 | // (2) a b c d e g h i j k l m |
468 | // ^ |
469 | // error here |
470 | if (equal->IsFailDeferralEnabled()) { |
471 | if (lhs->size() > min_size) { |
472 | equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size), |
473 | array_paths->rhs_path->MissingArrayElement(min_size)}); |
474 | } else { |
475 | equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size), |
476 | array_paths->rhs_path->ArrayIndex(min_size)}); |
477 | } |
478 | // Can return `true` pretending that everything is good since we have deferred the failure. |
479 | return true; |
480 | } |
481 | return false; |
482 | } |
483 | }; |
484 | TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) |
485 | .set_creator([](const std::string&) -> ObjectPtr<Object> { |
486 | return ::tvm::runtime::make_object<ArrayNode>(); |
487 | }); |
488 | |
489 | struct ShapeTupleObjTrait { |
490 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
491 | |
492 | static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) { |
493 | hash_reduce(self->size); |
494 | for (size_t i = 0; i < self->size; ++i) { |
495 | hash_reduce(self->data[i]); |
496 | } |
497 | } |
498 | |
499 | static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs, |
500 | SEqualReducer equal) { |
501 | if (lhs->size != rhs->size) return false; |
502 | for (size_t i = 0; i < lhs->size; ++i) { |
503 | if (!equal(lhs->data[i], rhs->data[i])) return false; |
504 | } |
505 | return true; |
506 | } |
507 | }; |
508 | |
509 | TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) |
510 | .set_creator([](const std::string& blob) { |
511 | // Store shape tuple in blob to avoid large integer overflow in JSON. |
512 | dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob)); |
513 | support::Base64InStream b64strm(&mstrm); |
514 | b64strm.InitPosition(); |
515 | uint64_t size; |
516 | b64strm.Read<uint64_t>(&size); |
517 | std::vector<int64_t> data(size); |
518 | b64strm.ReadArray(data.data(), size); |
519 | ShapeTuple shape(data); |
520 | return RefToObjectPtr::Get(shape); |
521 | }) |
522 | .set_repr_bytes([](const Object* n) -> std::string { |
523 | std::string blob; |
524 | dmlc::MemoryStringStream mstrm(&blob); |
525 | support::Base64OutStream b64strm(&mstrm); |
526 | const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n); |
527 | b64strm.Write<uint64_t>(shape->size); |
528 | b64strm.WriteArray(shape->data, shape->size); |
529 | b64strm.Finish(); |
530 | return blob; |
531 | }); |
532 | |
533 | struct MapNodeTrait { |
534 | static constexpr const std::nullptr_t VisitAttrs = nullptr; |
535 | |
536 | static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) { |
537 | // SHash's var handling depends on the determinism of traversal. |
538 | // NOTE: only book-keep the mapped hash keys. |
539 | // This resolves common use cases where we want to store |
540 | // Map<Var, Value> where Var is defined in the function |
541 | // parameters. |
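// For illustration (an assumed but typical case): when the map being hashed is
// a Map<Var, Buffer> keyed by a function's own parameters, only the entries
// whose Var keys were already hashed while visiting the parameters contribute
// here; entries whose keys were never reached are skipped, so the result does
// not depend on the map's iteration order.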
542 | using KV = std::pair<size_t, ObjectRef>; |
543 | std::vector<KV> temp; |
544 | for (const auto& kv : *key) { |
545 | size_t hashed_value; |
546 | if (hash_reduce->LookupHashedValue(kv.first, &hashed_value)) { |
547 | temp.emplace_back(hashed_value, kv.second); |
548 | } |
549 | } |
550 | // sort by the hashes of the keys. |
551 | std::sort(temp.begin(), temp.end(), |
552 | [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); |
553 | // add size to the hash |
554 | hash_reduce(static_cast<uint64_t>(key->size())); |
555 | // hash the content |
556 | for (size_t i = 0; i < temp.size();) { |
557 | size_t k = i + 1; |
558 | for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { |
559 | } |
560 | // ties are rare, but we need to skip them to make the hash deterministic |
561 | if (k == i + 1) { |
562 | hash_reduce->SHashReduceHashedValue(temp[i].first); |
563 | hash_reduce(temp[i].second); |
564 | } |
565 | i = k; |
566 | } |
567 | } |
568 | |
569 | static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) { |
570 | // NOTE: only book-keep the mapped hash keys. |
571 | // This resolves common use cases where we want to store |
572 | // Map<Var, Value> where Var is defined in the function |
573 | // parameters. |
574 | using KV = std::pair<String, ObjectRef>; |
575 | std::vector<KV> temp; |
576 | for (const auto& kv : *key) { |
577 | temp.push_back(std::make_pair(Downcast<String>(kv.first), kv.second)); |
578 | } |
579 | // sort by the string keys. |
580 | std::sort(temp.begin(), temp.end(), |
581 | [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); |
582 | // NOTE: we won't have ties |
583 | // add size to the hash after sorting. |
584 | hash_reduce(static_cast<uint64_t>(key->size())); |
585 | // hash the content |
586 | for (size_t i = 0; i < temp.size(); ++i) { |
587 | hash_reduce(temp[i].first); |
588 | hash_reduce(temp[i].second); |
589 | } |
590 | } |
591 | |
592 | static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { |
593 | bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { |
594 | return v.first->template IsInstance<StringObj>(); |
595 | }); |
596 | if (is_str_map) { |
597 | SHashReduceForSMap(key, hash_reduce); |
598 | } else { |
599 | SHashReduceForOMap(key, hash_reduce); |
600 | } |
601 | } |
602 | |
603 | static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { |
604 | for (const auto& kv : *lhs) { |
605 | // Only allow equal checking if the keys are already mapped |
606 | // This resolves common use cases where we want to store |
607 | // Map<Var, Value> where Var is defined in the function |
608 | // parameters. |
609 | ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); |
610 | if (!rhs_key.defined()) return false; |
611 | auto it = rhs->find(rhs_key); |
612 | if (it == rhs->end()) return false; |
613 | if (!equal(kv.second, it->second)) return false; |
614 | } |
615 | return true; |
616 | } |
617 | |
618 | static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { |
619 | for (const auto& kv : *lhs) { |
620 | auto it = rhs->find(kv.first); |
621 | if (it == rhs->end()) return false; |
622 | if (!equal(kv.second, it->second)) return false; |
623 | } |
624 | return true; |
625 | } |
626 | |
627 | static bool IsStringMap(const MapNode* map) { |
628 | return std::all_of(map->begin(), map->end(), |
629 | [](const auto& v) { return v.first->template IsInstance<StringObj>(); }); |
630 | } |
631 | |
632 | static bool SEqualReduceTracedForOMap(const MapNode* lhs, const MapNode* rhs, |
633 | const SEqualReducer& equal) { |
634 | const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); |
635 | |
636 | std::vector<const Object*> seen_rhs_keys; |
637 | |
638 | // First, check that every key from `lhs` is also in `rhs`, |
639 | // and their values are mapped to each other. |
640 | for (const auto& kv : *lhs) { |
641 | ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first); |
642 | |
643 | ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); |
644 | if (!rhs_key.defined()) { |
645 | equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()}); |
646 | return false; |
647 | } |
648 | |
649 | auto it = rhs->find(rhs_key); |
650 | if (it == rhs->end()) { |
651 | equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()}); |
652 | return false; |
653 | } |
654 | |
655 | if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) { |
656 | return false; |
657 | } |
658 | |
659 | seen_rhs_keys.push_back(it->first.get()); |
660 | } |
661 | |
662 | std::sort(seen_rhs_keys.begin(), seen_rhs_keys.end()); |
663 | |
664 | // Second, check that we have visited every `rhs` key when iterating over `lhs`. |
665 | for (const auto& kv : *rhs) { |
666 | if (!std::binary_search(seen_rhs_keys.begin(), seen_rhs_keys.end(), kv.first.get())) { |
667 | equal.RecordMismatchPaths( |
668 | {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)}); |
669 | return false; |
670 | } |
671 | } |
672 | |
673 | ICHECK(lhs->size() == rhs->size()); |
674 | return true; |
675 | } |
676 | |
677 | static bool SEqualReduceTracedForSMap(const MapNode* lhs, const MapNode* rhs, |
678 | const SEqualReducer& equal) { |
679 | const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); |
680 | |
681 | // First, check that every key from `lhs` is also in `rhs`, and their values are equal. |
682 | for (const auto& kv : *lhs) { |
683 | ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first); |
684 | auto it = rhs->find(kv.first); |
685 | if (it == rhs->end()) { |
686 | equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()}); |
687 | return false; |
688 | } |
689 | |
690 | if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) { |
691 | return false; |
692 | } |
693 | } |
694 | |
695 | // Second, make sure every key from `rhs` is also in `lhs`. |
696 | for (const auto& kv : *rhs) { |
697 | ObjectPath rhs_path = map_paths->rhs_path->MapValue(kv.first); |
698 | if (!lhs->count(kv.first)) { |
699 | equal.RecordMismatchPaths({map_paths->lhs_path->MissingMapEntry(), rhs_path}); |
700 | return false; |
701 | } |
702 | } |
703 | |
704 | ICHECK(lhs->size() == rhs->size()); |
705 | return true; |
706 | } |
707 | |
708 | static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs, |
709 | const SEqualReducer& equal) { |
710 | if (IsStringMap(lhs)) { |
711 | return SEqualReduceTracedForSMap(lhs, rhs, equal); |
712 | } else { |
713 | return SEqualReduceTracedForOMap(lhs, rhs, equal); |
714 | } |
715 | } |
716 | |
717 | static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { |
718 | if (equal.IsPathTracingEnabled()) { |
719 | return SEqualReduceTraced(lhs, rhs, equal); |
720 | } |
721 | |
722 | if (rhs->size() != lhs->size()) return false; |
723 | if (rhs->size() == 0) return true; |
724 | bool ls = IsStringMap(lhs); |
725 | bool rs = IsStringMap(rhs); |
726 | if (ls != rs) { |
727 | return false; |
728 | } |
729 | return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal); |
730 | } |
731 | }; |
732 | TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) |
733 | .set_creator([](const std::string&) -> ObjectPtr<Object> { return MapNode::Empty(); }); |
734 | |
735 | struct ReportNodeTrait { |
736 | static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) { |
737 | attrs->Visit("calls", &report->calls); |
738 | attrs->Visit("device_metrics", &report->device_metrics); |
739 | attrs->Visit("configuration", &report->configuration); |
740 | } |
741 | static constexpr std::nullptr_t SEqualReduce = nullptr; |
742 | static constexpr std::nullptr_t SHashReduce = nullptr; |
743 | }; |
744 | TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::ReportNode, ReportNodeTrait); |
745 | |
746 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
747 | .set_dispatch<runtime::profiling::ReportNode>([](const ObjectRef& node, ReprPrinter* p) { |
748 | auto* op = static_cast<const runtime::profiling::ReportNode*>(node.get()); |
749 | p->stream << op->AsTable(); |
750 | }); |
751 | |
752 | struct CountNodeTrait { |
753 | static void VisitAttrs(runtime::profiling::CountNode* n, AttrVisitor* attrs) { |
754 | attrs->Visit("value", &n->value); |
755 | } |
756 | static constexpr std::nullptr_t SEqualReduce = nullptr; |
757 | static constexpr std::nullptr_t SHashReduce = nullptr; |
758 | }; |
759 | TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::CountNode, CountNodeTrait); |
760 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
761 | .set_dispatch<runtime::profiling::CountNode>([](const ObjectRef& node, ReprPrinter* p) { |
762 | auto* op = static_cast<const runtime::profiling::CountNode*>(node.get()); |
763 | p->stream << op->GetTypeKey() << "(" << op->value << ")"; |
764 | }); |
765 | struct DurationNodeTrait { |
766 | static void VisitAttrs(runtime::profiling::DurationNode* n, AttrVisitor* attrs) { |
767 | attrs->Visit("microseconds", &n->microseconds); |
768 | } |
769 | static constexpr std::nullptr_t SEqualReduce = nullptr; |
770 | static constexpr std::nullptr_t SHashReduce = nullptr; |
771 | }; |
772 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
773 | .set_dispatch<runtime::profiling::DurationNode>([](const ObjectRef& node, ReprPrinter* p) { |
774 | auto* op = static_cast<const runtime::profiling::DurationNode*>(node.get()); |
775 | p->stream << op->GetTypeKey() << "(" << op->microseconds << ")"; |
776 | }); |
777 | TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::DurationNode, DurationNodeTrait); |
778 | struct PercentNodeTrait { |
779 | static void VisitAttrs(runtime::profiling::PercentNode* n, AttrVisitor* attrs) { |
780 | attrs->Visit("percent", &n->percent); |
781 | } |
782 | static constexpr std::nullptr_t SEqualReduce = nullptr; |
783 | static constexpr std::nullptr_t SHashReduce = nullptr; |
784 | }; |
785 | TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::PercentNode, PercentNodeTrait); |
786 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
787 | .set_dispatch<runtime::profiling::PercentNode>([](const ObjectRef& node, ReprPrinter* p) { |
788 | auto* op = static_cast<const runtime::profiling::PercentNode*>(node.get()); |
789 | p->stream << op->GetTypeKey() << "(" << op->percent << ")"; |
790 | }); |
791 | struct RatioNodeTrait { |
792 | static void VisitAttrs(runtime::profiling::RatioNode* n, AttrVisitor* attrs) { |
793 | attrs->Visit("ratio", &n->ratio); |
794 | } |
795 | static constexpr std::nullptr_t SEqualReduce = nullptr; |
796 | static constexpr std::nullptr_t SHashReduce = nullptr; |
797 | }; |
798 | TVM_REGISTER_REFLECTION_VTABLE(runtime::profiling::RatioNode, RatioNodeTrait); |
799 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
800 | .set_dispatch<runtime::profiling::RatioNode>([](const ObjectRef& node, ReprPrinter* p) { |
801 | auto* op = static_cast<const runtime::profiling::RatioNode*>(node.get()); |
802 | p->stream << op->GetTypeKey() << "(" << op->ratio << ")"; |
803 | }); |
804 | |
805 | } // namespace tvm |
806 | |