1#include <gtest/gtest.h>
2
3#include <c10/util/Exception.h>
4#include <torch/csrc/lazy/core/cache.h>
5#include <torch/csrc/lazy/core/hash.h>
6#include <torch/csrc/lazy/core/ir.h>
7#include <torch/csrc/lazy/core/shape.h>
8#include <torch/csrc/lazy/ts_backend/ts_node.h>
9
10namespace torch {
11namespace lazy {
12
13class CacheNode : public Node {
14 public:
15 explicit CacheNode(const std::string& str)
16 : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(str)), str_(str) {}
17 ~CacheNode() override = default;
18
19 const std::vector<Output>& operands() const override {
20 TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node");
21 }
22
23 const Output& operand(size_t i) const override {
24 TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node");
25 }
26
27 hash_t hash() const override {
28 return hash_;
29 }
30 hash_t shapeHash() const override {
31 return hash_;
32 }
33
34 private:
35 hash_t hash_;
36 std::string str_;
37};
38
39TEST(CacheTest, BasicTest) {
40 std::shared_ptr<CacheNode> a = std::make_shared<CacheNode>("a");
41 std::shared_ptr<CacheNode> b = std::make_shared<CacheNode>("b");
42 std::shared_ptr<CacheNode> c = std::make_shared<CacheNode>("c");
43 Cache<hash_t, CacheNode, HashReducer> cache(2);
44
45 cache.Add(a->hash(), a);
46 EXPECT_EQ(cache.Get(a->hash()), a);
47 EXPECT_EQ(cache.Get(b->hash()), nullptr);
48 EXPECT_EQ(cache.Get(c->hash()), nullptr);
49
50 cache.Add(b->hash(), b);
51 EXPECT_EQ(cache.Get(a->hash()), a);
52 EXPECT_EQ(cache.Get(b->hash()), b);
53 EXPECT_EQ(cache.Get(c->hash()), nullptr);
54
55 cache.Add(c->hash(), c);
56 EXPECT_EQ(cache.Get(a->hash()), nullptr); // a has been evicted
57 EXPECT_EQ(cache.Get(b->hash()), b);
58 EXPECT_EQ(cache.Get(c->hash()), c);
59
60 cache.Erase(c->hash());
61 EXPECT_EQ(cache.Get(a->hash()), nullptr);
62 EXPECT_EQ(cache.Get(b->hash()), b);
63 EXPECT_EQ(cache.Get(c->hash()), nullptr); // c has been removed
64
65 cache.Clear();
66 EXPECT_EQ(cache.Get(a->hash()), nullptr);
67 EXPECT_EQ(cache.Get(b->hash()), nullptr);
68 EXPECT_EQ(cache.Get(c->hash()), nullptr);
69}
70
71class CacheNodeWithShape : public TsNode {
72 public:
73 explicit CacheNodeWithShape(const Shape& shape)
74 : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0) {}
75};
76
77TEST(CacheTest, ShapeCacheTestForDynamicShape) {
78 // enable dynamic shape
79 FLAGS_ltc_enable_dynamic_shapes = true;
80
81 CacheNodeWithShape nodes[] = {
82 CacheNodeWithShape(Shape(c10::kFloat, {2, 4})),
83 CacheNodeWithShape(Shape(c10::kFloat, {4, 2}))};
84
85 /*
86 * Make sure the cached shape for node (2, 4) is not used for node (4, 2)
87 */
88 for (auto& node : nodes) {
89 EXPECT_EQ(node.shape(), node.computeShape([&]() { return node.shape(); }));
90 }
91
92 // reset the flag
93 FLAGS_ltc_enable_dynamic_shapes = false;
94}
95
96} // namespace lazy
97} // namespace torch
98