1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/lazy/core/cache.h> |
5 | #include <torch/csrc/lazy/core/hash.h> |
6 | #include <torch/csrc/lazy/core/ir.h> |
7 | #include <torch/csrc/lazy/core/shape.h> |
8 | #include <torch/csrc/lazy/ts_backend/ts_node.h> |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | |
13 | class CacheNode : public Node { |
14 | public: |
15 | explicit CacheNode(const std::string& str) |
16 | : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(str)), str_(str) {} |
17 | ~CacheNode() override = default; |
18 | |
19 | const std::vector<Output>& operands() const override { |
20 | TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node" ); |
21 | } |
22 | |
23 | const Output& operand(size_t i) const override { |
24 | TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node" ); |
25 | } |
26 | |
27 | hash_t hash() const override { |
28 | return hash_; |
29 | } |
30 | hash_t shapeHash() const override { |
31 | return hash_; |
32 | } |
33 | |
34 | private: |
35 | hash_t hash_; |
36 | std::string str_; |
37 | }; |
38 | |
// Exercises Add / Get / Erase / Clear on a Cache constructed with capacity 2.
// Adding a third entry must evict one (here `a`), Erase must remove exactly the
// given key, and Clear must empty the cache. Get() on a missing key is expected
// to return nullptr.
//
// NOTE(review): this sequence is order-sensitive. If Get() refreshes an entry's
// recency (LRU behaviour — confirm against cache.h), the order of the Get calls
// before `cache.Add(c->hash(), c)` decides which entry is evicted. Do not
// reorder or deduplicate these statements without re-deriving the expectations.
TEST(CacheTest, BasicTest) {
  std::shared_ptr<CacheNode> a = std::make_shared<CacheNode>("a");
  std::shared_ptr<CacheNode> b = std::make_shared<CacheNode>("b");
  std::shared_ptr<CacheNode> c = std::make_shared<CacheNode>("c");
  // Holds at most two entries (the eviction below depends on this).
  Cache<hash_t, CacheNode, HashReducer> cache(2);

  // One entry: only `a` is retrievable.
  cache.Add(a->hash(), a);
  EXPECT_EQ(cache.Get(a->hash()), a);
  EXPECT_EQ(cache.Get(b->hash()), nullptr);
  EXPECT_EQ(cache.Get(c->hash()), nullptr);

  // Cache is now full: `a` and `b` are both retrievable.
  cache.Add(b->hash(), b);
  EXPECT_EQ(cache.Get(a->hash()), a);
  EXPECT_EQ(cache.Get(b->hash()), b);
  EXPECT_EQ(cache.Get(c->hash()), nullptr);

  // Adding a third entry over capacity must evict one of the existing two.
  cache.Add(c->hash(), c);
  EXPECT_EQ(cache.Get(a->hash()), nullptr); // a has been evicted
  EXPECT_EQ(cache.Get(b->hash()), b);
  EXPECT_EQ(cache.Get(c->hash()), c);

  // Erase removes only the requested key.
  cache.Erase(c->hash());
  EXPECT_EQ(cache.Get(a->hash()), nullptr);
  EXPECT_EQ(cache.Get(b->hash()), b);
  EXPECT_EQ(cache.Get(c->hash()), nullptr); // c has been removed

  // Clear drops every remaining entry.
  cache.Clear();
  EXPECT_EQ(cache.Get(a->hash()), nullptr);
  EXPECT_EQ(cache.Get(b->hash()), nullptr);
  EXPECT_EQ(cache.Get(c->hash()), nullptr);
}
70 | |
71 | class CacheNodeWithShape : public TsNode { |
72 | public: |
73 | explicit CacheNodeWithShape(const Shape& shape) |
74 | : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0) {} |
75 | }; |
76 | |
77 | TEST(CacheTest, ShapeCacheTestForDynamicShape) { |
78 | // enable dynamic shape |
79 | FLAGS_ltc_enable_dynamic_shapes = true; |
80 | |
81 | CacheNodeWithShape nodes[] = { |
82 | CacheNodeWithShape(Shape(c10::kFloat, {2, 4})), |
83 | CacheNodeWithShape(Shape(c10::kFloat, {4, 2}))}; |
84 | |
85 | /* |
86 | * Make sure the cached shape for node (2, 4) is not used for node (4, 2) |
87 | */ |
88 | for (auto& node : nodes) { |
89 | EXPECT_EQ(node.shape(), node.computeShape([&]() { return node.shape(); })); |
90 | } |
91 | |
92 | // reset the flag |
93 | FLAGS_ltc_enable_dynamic_shapes = false; |
94 | } |
95 | |
96 | } // namespace lazy |
97 | } // namespace torch |
98 | |