#include <torch/csrc/lazy/core/trie.h>

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/metrics.h>

#include <fstream>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
9
10namespace torch {
11namespace lazy {
12namespace {
13
14void TraverseTrie(TrieNode* node, std::stringstream& ss) {
15 if (!node) {
16 return;
17 }
18 if (node->ir_node) {
19 ss << node->unique_id << "[label=\"" << node->ir_node->op().ToString()
20 << ", " << node->hit_counter << " hits\"]\n";
21 }
22 for (auto& successor : node->successors) {
23 ss << node->unique_id << " -> " << successor->unique_id << "\n";
24 TraverseTrie(successor.get(), ss);
25 }
26}
27} // namespace
28
29TrieCache* TrieCache::Get() {
30 static thread_local TrieCache* trie = new TrieCache();
31 return trie;
32}
33
34TrieCache::TrieCache()
35 : root_(std::make_shared<TrieNode>()), current_(root_.get()) {}
36
37TrieNode* TrieCache::Current() const {
38 return current_;
39}
40
41void TrieCache::SetCurrent(
42 std::list<std::shared_ptr<TrieNode>>::iterator& iter) {
43 auto& successors = current_->successors;
44 // Update current_ before iter gets destroyed
45 current_ = (*iter).get();
46
47 // Insert this node to the front of its parent's successor list
48 if (iter != successors.begin()) {
49 successors.push_front(std::move(*iter));
50 successors.erase(iter);
51 }
52}
53
54void TrieCache::ResetCurrent() {
55 current_ = root_.get();
56}
57
58void TrieCache::Insert(NodePtr ir_node) {
59 TORCH_CHECK(current_);
60 if (!current_->successors.empty()) {
61 TORCH_LAZY_COUNTER("TrieForked", 1);
62 }
63 auto new_node = std::make_shared<TrieNode>(std::move(ir_node));
64 current_->successors.push_front(std::move(new_node));
65 // Update current_ to the newly inserted node
66 current_ = current_->successors.front().get();
67}
68
69void TrieCache::Clear() {
70 ResetCurrent();
71 // Clear at the root level should be sufficient because all the nodes
72 // are created as shared_ptr.
73 root_->successors.clear();
74}
75
76void TrieCache::DumpToDotFile(const std::string& file_name) {
77 std::stringstream ss;
78 ss << "digraph G {\n";
79 TraverseTrie(root_.get(), ss);
80 ss << "}\n";
81
82 std::ofstream graph_file(file_name);
83 graph_file << ss.str();
84}
85
86} // namespace lazy
87} // namespace torch
88