#pragma once

#include <atomic>
#include <list>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch {
namespace lazy {
13 | struct TORCH_API TrieNode { |
14 | static size_t GetNextUniqueId() { |
15 | static thread_local size_t id_generator = 0; |
16 | return id_generator++; |
17 | } |
18 | |
19 | size_t unique_id; |
20 | size_t hit_counter; |
21 | NodePtr ir_node; |
22 | std::list<std::shared_ptr<TrieNode>> successors; |
23 | |
24 | TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {} |
25 | explicit TrieNode(NodePtr node) |
26 | : unique_id(GetNextUniqueId()), |
27 | hit_counter(0), |
28 | ir_node(std::move(node)) {} |
29 | }; |
30 | |
// Trie of previously traced IR nodes, used to reuse IR nodes across
// tracing steps. Holds a root sentinel and a cursor (`current_`) into
// the trie that advances as the current trace matches cached nodes.
// NOTE(review): Get() suggests a singleton accessor — with the
// thread_local id generator above, presumably one cache per thread;
// confirm against the .cpp definition.
class TORCH_API TrieCache {
 public:
  // Accessor for the cache instance (defined elsewhere).
  static TrieCache* Get();

  // Returns the node the tracing cursor currently points at.
  TrieNode* Current() const;
  // Take an iterator as the input because we want to move the corresponding
  // node in the successor list to achieve a LRU caching effect
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one tracing
  void ResetCurrent();

  // Create a new TrieNode for ir_node and insert into the TrieCache
  void Insert(NodePtr ir_node);

  // Clear all TrieCache nodes
  // TODO: Because we don't expect user to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the size of the cache
  // to avoid holding tensors for too long.
  void Clear();

  // Debug helper: dumps the trie in Graphviz dot format to `file_name`.
  void DumpToDotFile(const std::string& file_name);

 private:
  TrieCache();

  std::shared_ptr<TrieNode> root_; // sentinel; owns the whole trie
  TrieNode* current_;              // non-owning cursor into the trie
};
60 | template <typename T, typename... Args> |
61 | NodePtr LookupNodeFromTrieCache(Args&&... args) { |
62 | auto& successors = TrieCache::Get()->Current()->successors; |
63 | for (auto it = successors.begin(); it != successors.end(); it++) { |
64 | NodePtr ir_node = (*it)->ir_node; |
65 | const T* concrete_node = NodeCast<T>(ir_node.get()); |
66 | if (concrete_node && |
67 | concrete_node->CanBeReused(std::forward<Args>(args)...)) { |
68 | TORCH_LAZY_COUNTER( |
69 | "IrNodeReused_"+ c10::demangle((typeid(T).name())), 1); |
70 | (*it)->hit_counter++; |
71 | TrieCache::Get()->SetCurrent(it); |
72 | return ir_node; |
73 | } |
74 | } |
75 | return nullptr; |
76 | } |

} // namespace lazy
} // namespace torch