1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/lazy/core/config.h> |
5 | #include <torch/csrc/lazy/core/ir.h> |
6 | #include <torch/csrc/lazy/core/ir_builder.h> |
7 | #include <torch/csrc/lazy/core/ir_metadata.h> |
8 | #include <torch/csrc/lazy/core/ir_util.h> |
9 | #include <memory> |
10 | |
11 | namespace torch { |
12 | namespace lazy { |
13 | |
14 | class TrieCacheNode : public Node { |
15 | public: |
16 | static OpKind ClassOpKind() { |
17 | return OpKind(); |
18 | } |
19 | |
20 | explicit TrieCacheNode(size_t id) |
21 | : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {} |
22 | ~TrieCacheNode() override = default; |
23 | |
24 | bool CanBeReused(size_t id) const { |
25 | return (id_ == id); |
26 | } |
27 | |
28 | void AddOperand(Value v) { |
29 | if (!v.node) { |
30 | return; |
31 | } |
32 | operands_as_outputs_.emplace_back(v.node.get(), v.index); |
33 | operands_.push_back(std::move(v.node)); |
34 | } |
35 | |
36 | hash_t hash() const override { |
37 | return hash_; |
38 | } |
39 | hash_t shapeHash() const override { |
40 | return hash_; |
41 | } |
42 | |
43 | private: |
44 | size_t id_; |
45 | hash_t hash_; |
46 | }; |
47 | |
48 | TEST(TrieCacheTest, TestSinglePath) { |
49 | FLAGS_torch_lazy_reuse_ir = true; |
50 | TrieCache::Get()->Clear(); |
51 | |
52 | NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0); |
53 | NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1); |
54 | NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2); |
55 | TrieCache::Get()->ResetCurrent(); // MarkStep |
56 | |
57 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get()); |
58 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get()); |
59 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get()); |
60 | TrieCache::Get()->ResetCurrent(); // MarkStep |
61 | } |
62 | |
63 | /* |
64 | * 0 |
65 | * | |
66 | * 1 |
67 | * / \ |
68 | * 2 3 |
69 | */ |
70 | TEST(TrieCacheTest, TestTwoPaths) { |
71 | FLAGS_torch_lazy_reuse_ir = true; |
72 | TrieCache::Get()->Clear(); |
73 | |
74 | NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0); |
75 | NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1); |
76 | NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2); |
77 | TrieCache::Get()->ResetCurrent(); // MarkStep |
78 | |
79 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get()); |
80 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get()); |
81 | NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3); |
82 | EXPECT_NE(d.get(), c.get()); |
83 | TrieCache::Get()->ResetCurrent(); // MarkStep |
84 | |
85 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get()); |
86 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get()); |
87 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get()); |
88 | TrieCache::Get()->ResetCurrent(); // MarkStep |
89 | |
90 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get()); |
91 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get()); |
92 | EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get()); |
93 | TrieCache::Get()->ResetCurrent(); // MarkStep |
94 | } |
95 | |
96 | } // namespace lazy |
97 | } // namespace torch |
98 | |