1#include <gtest/gtest.h>
2
3#include <c10/util/Exception.h>
4#include <torch/csrc/lazy/core/config.h>
5#include <torch/csrc/lazy/core/ir.h>
6#include <torch/csrc/lazy/core/ir_builder.h>
7#include <torch/csrc/lazy/core/ir_metadata.h>
8#include <torch/csrc/lazy/core/ir_util.h>
9#include <memory>
10
11namespace torch {
12namespace lazy {
13
14class TrieCacheNode : public Node {
15 public:
16 static OpKind ClassOpKind() {
17 return OpKind();
18 }
19
20 explicit TrieCacheNode(size_t id)
21 : Node(ClassOpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
22 ~TrieCacheNode() override = default;
23
24 bool CanBeReused(size_t id) const {
25 return (id_ == id);
26 }
27
28 void AddOperand(Value v) {
29 if (!v.node) {
30 return;
31 }
32 operands_as_outputs_.emplace_back(v.node.get(), v.index);
33 operands_.push_back(std::move(v.node));
34 }
35
36 hash_t hash() const override {
37 return hash_;
38 }
39 hash_t shapeHash() const override {
40 return hash_;
41 }
42
43 private:
44 size_t id_;
45 hash_t hash_;
46};
47
48TEST(TrieCacheTest, TestSinglePath) {
49 FLAGS_torch_lazy_reuse_ir = true;
50 TrieCache::Get()->Clear();
51
52 NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
53 NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
54 NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
55 TrieCache::Get()->ResetCurrent(); // MarkStep
56
57 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
58 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
59 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
60 TrieCache::Get()->ResetCurrent(); // MarkStep
61}
62
63/*
64 * 0
65 * |
66 * 1
67 * / \
68 * 2 3
69 */
70TEST(TrieCacheTest, TestTwoPaths) {
71 FLAGS_torch_lazy_reuse_ir = true;
72 TrieCache::Get()->Clear();
73
74 NodePtr a = ReuseOrMakeNode<TrieCacheNode>(0);
75 NodePtr b = ReuseOrMakeNode<TrieCacheNode>(1);
76 NodePtr c = ReuseOrMakeNode<TrieCacheNode>(2);
77 TrieCache::Get()->ResetCurrent(); // MarkStep
78
79 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
80 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
81 NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
82 EXPECT_NE(d.get(), c.get());
83 TrieCache::Get()->ResetCurrent(); // MarkStep
84
85 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
86 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
87 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
88 TrieCache::Get()->ResetCurrent(); // MarkStep
89
90 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
91 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
92 EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
93 TrieCache::Get()->ResetCurrent(); // MarkStep
94}
95
96} // namespace lazy
97} // namespace torch
98