1 | #include <torch/csrc/lazy/core/trie.h> |
---|---|
2 | |
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <fstream>
#include <sstream>
#include <utility>
#include <vector>
8 | #include <sstream> |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | namespace { |
13 | |
14 | void TraverseTrie(TrieNode* node, std::stringstream& ss) { |
15 | if (!node) { |
16 | return; |
17 | } |
18 | if (node->ir_node) { |
19 | ss << node->unique_id << "[label=\""<< node->ir_node->op().ToString() |
20 | << ", "<< node->hit_counter << " hits\"]\n"; |
21 | } |
22 | for (auto& successor : node->successors) { |
23 | ss << node->unique_id << " -> "<< successor->unique_id << "\n"; |
24 | TraverseTrie(successor.get(), ss); |
25 | } |
26 | } |
27 | } // namespace |
28 | |
29 | TrieCache* TrieCache::Get() { |
30 | static thread_local TrieCache* trie = new TrieCache(); |
31 | return trie; |
32 | } |
33 | |
34 | TrieCache::TrieCache() |
35 | : root_(std::make_shared<TrieNode>()), current_(root_.get()) {} |
36 | |
37 | TrieNode* TrieCache::Current() const { |
38 | return current_; |
39 | } |
40 | |
41 | void TrieCache::SetCurrent( |
42 | std::list<std::shared_ptr<TrieNode>>::iterator& iter) { |
43 | auto& successors = current_->successors; |
44 | // Update current_ before iter gets destroyed |
45 | current_ = (*iter).get(); |
46 | |
47 | // Insert this node to the front of its parent's successor list |
48 | if (iter != successors.begin()) { |
49 | successors.push_front(std::move(*iter)); |
50 | successors.erase(iter); |
51 | } |
52 | } |
53 | |
54 | void TrieCache::ResetCurrent() { |
55 | current_ = root_.get(); |
56 | } |
57 | |
58 | void TrieCache::Insert(NodePtr ir_node) { |
59 | TORCH_CHECK(current_); |
60 | if (!current_->successors.empty()) { |
61 | TORCH_LAZY_COUNTER("TrieForked", 1); |
62 | } |
63 | auto new_node = std::make_shared<TrieNode>(std::move(ir_node)); |
64 | current_->successors.push_front(std::move(new_node)); |
65 | // Update current_ to the newly inserted node |
66 | current_ = current_->successors.front().get(); |
67 | } |
68 | |
69 | void TrieCache::Clear() { |
70 | ResetCurrent(); |
71 | // Clear at the root level should be sufficient because all the nodes |
72 | // are created as shared_ptr. |
73 | root_->successors.clear(); |
74 | } |
75 | |
76 | void TrieCache::DumpToDotFile(const std::string& file_name) { |
77 | std::stringstream ss; |
78 | ss << "digraph G {\n"; |
79 | TraverseTrie(root_.get(), ss); |
80 | ss << "}\n"; |
81 | |
82 | std::ofstream graph_file(file_name); |
83 | graph_file << ss.str(); |
84 | } |
85 | |
86 | } // namespace lazy |
87 | } // namespace torch |
88 |