#pragma once

#include <atomic>
#include <list>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch {
namespace lazy {

struct TORCH_API TrieNode {
  static size_t GetNextUniqueId() {
    static thread_local size_t id_generator = 0;
    return id_generator++;
  }

  size_t unique_id;
  size_t hit_counter;
  NodePtr ir_node;
  std::list<std::shared_ptr<TrieNode>> successors;

  TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {}
  explicit TrieNode(NodePtr node)
      : unique_id(GetNextUniqueId()),
        hit_counter(0),
        ir_node(std::move(node)) {}
};

class TORCH_API TrieCache {
 public:
  static TrieCache* Get();

  TrieNode* Current() const;
  // Take an iterator as the input because we want to move the corresponding
  // node in the successor list to achieve an LRU caching effect.
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one trace; see the lifecycle
  // sketch below this class.
  void ResetCurrent();

  // Create a new TrieNode for ir_node and insert it into the TrieCache.
  void Insert(NodePtr ir_node);

  // Clear all TrieCache nodes.
  // TODO: Because we don't expect users to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the cache size to
  // avoid holding tensors for too long.
  void Clear();

  void DumpToDotFile(const std::string& file_name);

 private:
  TrieCache();

  std::shared_ptr<TrieNode> root_;
  TrieNode* current_;
};
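
// A minimal lifecycle sketch (hypothetical driver code, not part of this
// header), assuming Insert() appends under Current() and advances the cursor:
// while a program is retraced, node construction either reuses a cached
// successor of Current() or inserts a fresh node, and MarkStep rewinds the
// cursor so the next trace starts from the root again.
//
//   TrieCache* cache = TrieCache::Get();
//   cache->Insert(node);    // record a newly created IR node under Current()
//   ...                     // trace further nodes (or reuse cached ones)
//   cache->ResetCurrent();  // at MarkStep: the next trace starts at the root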

template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(Args&&... args) {
  auto& successors = TrieCache::Get()->Current()->successors;
  for (auto it = successors.begin(); it != successors.end(); it++) {
    NodePtr ir_node = (*it)->ir_node;
    const T* concrete_node = NodeCast<T>(ir_node.get());
    if (concrete_node &&
        concrete_node->CanBeReused(std::forward<Args>(args)...)) {
      TORCH_LAZY_COUNTER(
          "IrNodeReused_" + c10::demangle((typeid(T).name())), 1);
      (*it)->hit_counter++;
      TrieCache::Get()->SetCurrent(it);
      return ir_node;
    }
  }
  return nullptr;
}
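
// Usage sketch (hypothetical helper, not part of this header): a node
// construction site can consult the trie before allocating a fresh IR node,
// assuming T exposes a CanBeReused(...) overload mirroring its constructor
// arguments.
//
//   template <typename T, typename... Args>
//   NodePtr ReuseOrMakeNode(Args&&... args) {
//     // Reuse a node recorded by an earlier trace of the same program, if any.
//     NodePtr node = LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
//     if (!node) {
//       node = std::make_shared<T>(std::forward<Args>(args)...);
//       // Record the fresh node so subsequent traces can reuse it.
//       TrieCache::Get()->Insert(node);
//     }
//     return node;
//   }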

} // namespace lazy
} // namespace torch